diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 56f53cc16a983..f6152b232a938 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -33,7 +33,10 @@ from superset.common.query_actions import get_query_results from superset.common.utils import dataframe_utils from superset.common.utils.query_cache_manager import QueryCacheManager -from superset.common.utils.time_range_utils import get_since_until_from_query_object +from superset.common.utils.time_range_utils import ( + get_since_until_from_query_object, + get_since_until_from_time_range, +) from superset.connectors.base.models import BaseDatasource from superset.constants import CacheRegion, TimeGrain from superset.daos.annotation import AnnotationLayerDAO @@ -64,6 +67,7 @@ from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.utils.pandas_postprocessing.utils import unescape_separator from superset.views.utils import get_viz +from superset.viz import viz_types if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -685,22 +689,53 @@ def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]: def get_viz_annotation_data( annotation_layer: dict[str, Any], force: bool ) -> dict[str, Any]: - chart = ChartDAO.find_by_id(annotation_layer["value"]) - if not chart: + # pylint: disable=import-outside-toplevel,superfluous-parens + from superset.charts.data.commands.get_data_command import ChartDataCommand + + if not (chart := ChartDAO.find_by_id(annotation_layer["value"])): raise QueryObjectValidationError(_("The chart does not exist")) - if not chart.datasource: - raise QueryObjectValidationError(_("The chart datasource does not exist")) - form_data = chart.form_data.copy() - form_data.update(annotation_layer.get("overrides", {})) + try: - viz_obj = get_viz( - datasource_type=chart.datasource.type, - datasource_id=chart.datasource.id, - form_data=form_data, - force=force, - ) - payload = viz_obj.get_payload() - return payload["data"] + if chart.viz_type in viz_types: + if not chart.datasource: + raise QueryObjectValidationError( + _("The chart datasource does not exist"), + ) + + form_data = chart.form_data.copy() + form_data.update(annotation_layer.get("overrides", {})) + + payload = get_viz( + datasource_type=chart.datasource.type, + datasource_id=chart.datasource.id, + form_data=form_data, + force=force, + ).get_payload() + + return payload["data"] + + if not (query_context := chart.get_query_context()): + raise QueryObjectValidationError( + _("The chart query context does not exist"), + ) + + if overrides := annotation_layer.get("overrides"): + if time_grain_sqla := overrides.get("time_grain_sqla"): + for query_object in query_context.queries: + query_object.extras["time_grain_sqla"] = time_grain_sqla + + if time_range := overrides.get("time_range"): + from_dttm, to_dttm = get_since_until_from_time_range(time_range) + + for query_object in query_context.queries: + query_object.from_dttm = from_dttm + query_object.to_dttm = to_dttm + + query_context.force = force + command = ChartDataCommand(query_context) + command.validate() + payload = command.run() + return {"records": payload["queries"][0]["data"]} except SupersetException as ex: raise QueryObjectValidationError(error_msg_from_exception(ex)) from ex diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index e91640b19029f..c9e38f1686cae 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -424,6 +424,16 @@ def create_slices(tbl: SqlaTable) -> tuple[list[Slice], list[Slice]]: viz_type="table", metrics=metrics, ), + query_context=get_slice_json( + default_query_context, + queries=[ + { + "columns": ["ds"], + "metrics": metrics, + "time_range": "1983 : 2023", + } + ], + ), ), Slice( **slice_kwargs, diff --git a/superset/viz.py b/superset/viz.py index 4e39ae2a19f26..3051f104e20af 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -75,10 +75,8 @@ get_column_name, get_column_names, get_column_names_from_columns, - get_metric_names, JS_MAX_INTEGER, merge_extra_filters, - QueryMode, simple_filter_to_adhoc, ) from superset.utils.date_parser import get_since_until, parse_past_timedelta @@ -701,158 +699,6 @@ def raise_for_access(self) -> None: security_manager.raise_for_access(viz=self) -class TableViz(BaseViz): - - """A basic html table that is sortable and searchable""" - - viz_type = "table" - verbose_name = _("Table View") - credits = 'a Superset original' - is_timeseries = False - enforce_numerical_metrics = False - - @deprecated(deprecated_in="3.0") - def process_metrics(self) -> None: - """Process form data and store parsed column configs. - 1. Determine query mode based on form_data params. - - Use `query_mode` if it has a valid value - - Set as RAW mode if `all_columns` is set - - Otherwise defaults to AGG mode - 2. Determine output columns based on query mode. - """ - # Verify form data first: if not specifying query mode, then cannot have both - # GROUP BY and RAW COLUMNS. - if ( - not self.form_data.get("query_mode") - and self.form_data.get("all_columns") - and ( - self.form_data.get("groupby") - or self.form_data.get("metrics") - or self.form_data.get("percent_metrics") - ) - ): - raise QueryObjectValidationError( - _( - "You cannot use [Columns] in combination with " - "[Group By]/[Metrics]/[Percentage Metrics]. " - "Please choose one or the other." - ) - ) - - super().process_metrics() - - self.query_mode: QueryMode = QueryMode.get( - self.form_data.get("query_mode") - ) or ( - # infer query mode from the presence of other fields - QueryMode.RAW - if len(self.form_data.get("all_columns") or []) > 0 - else QueryMode.AGGREGATE - ) - - columns: list[str] # output columns sans time and percent_metric column - percent_columns: list[str] = [] # percent columns that needs extra computation - - if self.query_mode == QueryMode.RAW: - columns = get_metric_names(self.form_data.get("all_columns")) - else: - columns = get_column_names(self.groupby) + get_metric_names( - self.form_data.get("metrics") - ) - percent_columns = get_metric_names( - self.form_data.get("percent_metrics") or [] - ) - - self.columns = columns - self.percent_columns = percent_columns - self.is_timeseries = self.should_be_timeseries() - - @deprecated(deprecated_in="3.0") - def should_be_timeseries(self) -> bool: - # TODO handle datasource-type-specific code in datasource - conditions_met = self.form_data.get("granularity_sqla") and self.form_data.get( - "time_grain_sqla" - ) - if self.form_data.get("include_time") and not conditions_met: - raise QueryObjectValidationError( - _("Pick a granularity in the Time section or " "uncheck 'Include Time'") - ) - return bool(self.form_data.get("include_time")) - - @deprecated(deprecated_in="3.0") - def query_obj(self) -> QueryObjectDict: - query_obj = super().query_obj() - if self.query_mode == QueryMode.RAW: - query_obj["columns"] = self.form_data.get("all_columns") - order_by_cols = self.form_data.get("order_by_cols") or [] - query_obj["orderby"] = [json.loads(t) for t in order_by_cols] - # must disable groupby and metrics in raw mode - query_obj["groupby"] = [] - query_obj["metrics"] = [] - # raw mode does not support timeseries queries - query_obj["timeseries_limit_metric"] = None - query_obj["timeseries_limit"] = None - query_obj["is_timeseries"] = None - else: - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: - sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): - query_obj["metrics"].append(sort_by) - query_obj["orderby"] = [ - (sort_by, not self.form_data.get("order_desc", True)) - ] - elif query_obj["metrics"]: - # Legacy behavior of sorting by first metric by default - first_metric = query_obj["metrics"][0] - query_obj["orderby"] = [ - (first_metric, not self.form_data.get("order_desc", True)) - ] - return query_obj - - @deprecated(deprecated_in="3.0") - def get_data(self, df: pd.DataFrame) -> VizData: - """ - Transform the query result to the table representation. - - :param df: The interim dataframe - :returns: The table visualization data - - The interim dataframe comprises of the group-by and non-group-by columns and - the union of the metrics representing the non-percent and percent metrics. Note - the percent metrics have yet to be transformed. - """ - # Transform the data frame to adhere to the UI ordering of the columns and - # metrics whilst simultaneously computing the percentages (via normalization) - # for the percent metrics. - if df.empty: - return None - - columns, percent_columns = self.columns, self.percent_columns - if DTTM_ALIAS in df and self.is_timeseries: - columns = [DTTM_ALIAS] + columns - df = pd.concat( - [ - df[columns], - (df[percent_columns].div(df[percent_columns].sum()).add_prefix("%")), - ], - axis=1, - ) - return self.handle_js_int_overflow( - dict(records=df.to_dict(orient="records"), columns=list(df.columns)) - ) - - @staticmethod - @deprecated(deprecated_in="3.0") - def json_dumps(query_obj: Any, sort_keys: bool = False) -> str: - return json.dumps( - query_obj, - default=utils.json_iso_dttm_ser, - sort_keys=sort_keys, - ignore_nan=True, - ) - - class TimeTableViz(BaseViz): """A data table with rich time-series related columns""" @@ -1076,65 +922,6 @@ def get_data(self, df: pd.DataFrame) -> VizData: } -class BigNumberViz(BaseViz): - - """Put emphasis on a single metric with this big number viz""" - - viz_type = "big_number" - verbose_name = _("Big Number with Trendline") - credits = 'a Superset original' - is_timeseries = True - - @deprecated(deprecated_in="3.0") - def query_obj(self) -> QueryObjectDict: - query_obj = super().query_obj() - metric = self.form_data.get("metric") - if not metric: - raise QueryObjectValidationError(_("Pick a metric!")) - query_obj["metrics"] = [self.form_data.get("metric")] - self.form_data["metric"] = metric - return query_obj - - @deprecated(deprecated_in="3.0") - def get_data(self, df: pd.DataFrame) -> VizData: - if df.empty: - return None - - df = df.pivot_table( - index=DTTM_ALIAS, - columns=[], - values=self.metric_labels, - dropna=False, - aggfunc=np.min, # looking for any (only) value, preserving `None` - ) - df = self.apply_rolling(df) - df[DTTM_ALIAS] = df.index - return super().get_data(df) - - -class BigNumberTotalViz(BaseViz): - - """Put emphasis on a single metric with this big number viz""" - - viz_type = "big_number_total" - verbose_name = _("Big Number") - credits = 'a Superset original' - is_timeseries = False - - @deprecated(deprecated_in="3.0") - def query_obj(self) -> QueryObjectDict: - query_obj = super().query_obj() - metric = self.form_data.get("metric") - if not metric: - raise QueryObjectValidationError(_("Pick a metric!")) - query_obj["metrics"] = [self.form_data.get("metric")] - self.form_data["metric"] = metric - - # Limiting rows is not required as only one cell is returned - query_obj["row_limit"] = None - return query_obj - - class NVD3TimeSeriesViz(NVD3Viz): """A rich line chart component with tons of options""" diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py index a7da8a50d2a59..b2a8704dfb237 100644 --- a/tests/integration_tests/cache_tests.py +++ b/tests/integration_tests/cache_tests.py @@ -46,7 +46,7 @@ def test_no_data_cache(self): app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "NullCache"} cache_manager.init_app(app) - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) json_endpoint = "/superset/explore_json/{}/{}/".format( slc.datasource_type, slc.datasource_id ) @@ -73,7 +73,7 @@ def test_slice_data_cache(self): } cache_manager.init_app(app) - slc = self.get_slice("Boys", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) json_endpoint = "/superset/explore_json/{}/{}/".format( slc.datasource_type, slc.datasource_id ) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index f09662a8de919..734990a1eecc3 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -1715,7 +1715,7 @@ def test_gets_owned_created_favorited_by_me_filter(self): ) def test_warm_up_cache(self): self.login() - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id}) self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index 911f3bf5daa4c..9580c2bf33e6b 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -456,7 +456,7 @@ def test_warm_up_cache_command_chart_not_found(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache(self): - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) result = ChartWarmUpCacheCommand(slc.id, None, None).run() self.assertEqual( result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index c96555503b598..0acb19969a23d 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -173,7 +173,7 @@ def test_slice_endpoint(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_viz_cache_key(self): self.login(username="admin") - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) viz = slc.viz qobj = viz.query_obj() @@ -279,7 +279,9 @@ def test_slice_data(self): # slice data should have some required attributes self.login(username="admin") slc = self.get_slice( - slice_name="Girls", session=db.session, expunge_from_session=False + slice_name="Top 10 Girl Name Share", + session=db.session, + expunge_from_session=False, ) slc_data_attributes = slc.data.keys() assert "changed_on" in slc_data_attributes @@ -391,7 +393,7 @@ def test_databaseview_edit(self, username="admin"): ) def test_warm_up_cache(self): self.login() - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") self.assertEqual( data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] @@ -418,10 +420,10 @@ def test_cache_logging(self): self.login("admin") store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True - girls_slice = self.get_slice("Girls", db.session) - self.get_json_resp(f"/superset/warm_up_cache?slice_id={girls_slice.id}") + slc = self.get_slice("Top 10 Girl Name Share", db.session) + self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() - assert ck.datasource_uid == f"{girls_slice.table.id}__table" + assert ck.datasource_uid == f"{slc.table.id}__table" app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys def test_redirect_invalid(self): diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index fe199443875dd..1518a69f9dbbe 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1680,7 +1680,7 @@ def test_raise_for_access_table(self, mock_can_access): def test_raise_for_access_viz( self, mock_can_access_schema, mock_can_access, mock_is_owner ): - test_viz = viz.TableViz(self.get_datasource_mock(), form_data={}) + test_viz = viz.TimeTableViz(self.get_datasource_mock(), form_data={}) mock_can_access_schema.return_value = True security_manager.raise_for_access(viz=test_viz) diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 2986188ff98ec..c0383d1d0b75d 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -974,7 +974,7 @@ def test_get_form_data_corrupted_json(self) -> None: def test_log_this(self) -> None: # TODO: Add additional scenarios. self.login(username="admin") - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Top 10 Girl Name Share", db.session) dashboard_id = 1 assert slc.viz is not None diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py index ac390b3976a04..f1665e96888d0 100644 --- a/tests/integration_tests/viz_tests.py +++ b/tests/integration_tests/viz_tests.py @@ -45,7 +45,7 @@ def test_constructor_exception_no_datasource(self): viz.BaseViz(datasource, form_data) def test_process_metrics(self): - # test TableViz metrics in correct order + # test TimeTableViz metrics in correct order form_data = { "url_params": {}, "row_limit": 500, @@ -55,7 +55,7 @@ def test_process_metrics(self): "granularity_sqla": "year", "page_length": 0, "all_columns": [], - "viz_type": "table", + "viz_type": "time_table", "since": "2014-01-01", "until": "2014-01-02", "metrics": ["sum__SP_POP_TOTL", "SUM(SE_PRM_NENR_MA)", "SUM(SP_URB_TOTL)"], @@ -177,273 +177,6 @@ def test_cache_timeout(self): app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout -class TestTableViz(SupersetTestCase): - def test_get_data_applies_percentage(self): - form_data = { - "groupby": ["groupA", "groupB"], - "metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - }, - "count", - "avg__C", - ], - "percent_metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - }, - "avg__B", - ], - } - datasource = self.get_datasource_mock() - - df = pd.DataFrame( - { - "SUM(value1)": [15, 20, 25, 40], - "avg__B": [10, 20, 5, 15], - "avg__C": [11, 22, 33, 44], - "count": [6, 7, 8, 9], - "groupA": ["A", "B", "C", "C"], - "groupB": ["x", "x", "y", "z"], - } - ) - - test_viz = viz.TableViz(datasource, form_data) - data = test_viz.get_data(df) - # Check method correctly transforms data and computes percents - self.assertEqual( - [ - "groupA", - "groupB", - "SUM(value1)", - "count", - "avg__C", - "%SUM(value1)", - "%avg__B", - ], - list(data["columns"]), - ) - expected = [ - { - "groupA": "A", - "groupB": "x", - "SUM(value1)": 15, - "count": 6, - "avg__C": 11, - "%SUM(value1)": 0.15, - "%avg__B": 0.2, - }, - { - "groupA": "B", - "groupB": "x", - "SUM(value1)": 20, - "count": 7, - "avg__C": 22, - "%SUM(value1)": 0.2, - "%avg__B": 0.4, - }, - { - "groupA": "C", - "groupB": "y", - "SUM(value1)": 25, - "count": 8, - "avg__C": 33, - "%SUM(value1)": 0.25, - "%avg__B": 0.1, - }, - { - "groupA": "C", - "groupB": "z", - "SUM(value1)": 40, - "count": 9, - "avg__C": 44, - "%SUM(value1)": 0.4, - "%avg__B": 0.3, - }, - ] - self.assertEqual(expected, data["records"]) - - def test_parse_adhoc_filters(self): - form_data = { - "metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - } - ], - "adhoc_filters": [ - { - "expressionType": "SIMPLE", - "clause": "WHERE", - "subject": "value2", - "operator": ">", - "comparator": "100", - }, - { - "expressionType": "SQL", - "clause": "HAVING", - "sqlExpression": "SUM(value1) > 5", - }, - { - "expressionType": "SQL", - "clause": "WHERE", - "sqlExpression": "value3 in ('North America')", - }, - ], - } - datasource = self.get_datasource_mock() - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual( - [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"] - ) - self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"]) - self.assertEqual("(SUM(value1) > 5)", query_obj["extras"]["having"]) - - def test_adhoc_filters_overwrite_legacy_filters(self): - form_data = { - "metrics": [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - } - ], - "adhoc_filters": [ - { - "expressionType": "SIMPLE", - "clause": "WHERE", - "subject": "value2", - "operator": ">", - "comparator": "100", - }, - { - "expressionType": "SQL", - "clause": "WHERE", - "sqlExpression": "value3 in ('North America')", - }, - ], - "having": "SUM(value1) > 5", - } - datasource = self.get_datasource_mock() - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual( - [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"] - ) - self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"]) - self.assertEqual("", query_obj["extras"]["having"]) - - def test_query_obj_merges_percent_metrics(self): - datasource = self.get_datasource_mock() - form_data = { - "metrics": ["sum__A", "count", "avg__C"], - "percent_metrics": ["sum__A", "avg__B", "max__Y"], - } - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual( - ["sum__A", "count", "avg__C", "avg__B", "max__Y"], query_obj["metrics"] - ) - - def test_query_obj_throws_columns_and_metrics(self): - datasource = self.get_datasource_mock() - form_data = {"all_columns": ["A", "B"], "metrics": ["x", "y"]} - with self.assertRaises(Exception): - test_viz = viz.TableViz(datasource, form_data) - test_viz.query_obj() - del form_data["metrics"] - form_data["groupby"] = ["B", "C"] - with self.assertRaises(Exception): - test_viz = viz.TableViz(datasource, form_data) - test_viz.query_obj() - - @patch("superset.viz.BaseViz.query_obj") - def test_query_obj_merges_all_columns(self, super_query_obj): - datasource = self.get_datasource_mock() - form_data = { - "all_columns": ["colA", "colB", "colC"], - "order_by_cols": ['["colA", "colB"]', '["colC"]'], - } - super_query_obj.return_value = { - "columns": ["colD", "colC"], - "groupby": ["colA", "colB"], - } - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual(form_data["all_columns"], query_obj["columns"]) - self.assertEqual([], query_obj["groupby"]) - self.assertEqual([["colA", "colB"], ["colC"]], query_obj["orderby"]) - - def test_query_obj_uses_sortby(self): - datasource = self.get_datasource_mock() - form_data = { - "metrics": ["colA", "colB"], - "order_desc": False, - } - - def run_test(metric): - form_data["timeseries_limit_metric"] = metric - test_viz = viz.TableViz(datasource, form_data) - query_obj = test_viz.query_obj() - self.assertEqual(["colA", "colB", metric], query_obj["metrics"]) - self.assertEqual([(metric, True)], query_obj["orderby"]) - - run_test("simple_metric") - run_test( - { - "label": "adhoc_metric", - "expressionType": "SIMPLE", - "aggregate": "SUM", - "column": { - "column_name": "sort_column", - }, - } - ) - - def test_should_be_timeseries_raises_when_no_granularity(self): - datasource = self.get_datasource_mock() - form_data = {"include_time": True} - with self.assertRaises(Exception): - test_viz = viz.TableViz(datasource, form_data) - test_viz.should_be_timeseries() - - def test_adhoc_metric_with_sortby(self): - metrics = [ - { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "sum_value", - "column": {"column_name": "value1", "type": "DOUBLE"}, - } - ] - form_data = { - "metrics": metrics, - "timeseries_limit_metric": { - "expressionType": "SIMPLE", - "aggregate": "SUM", - "label": "SUM(value1)", - "column": {"column_name": "value1", "type": "DOUBLE"}, - }, - "order_desc": False, - } - - df = pd.DataFrame({"SUM(value1)": [15], "sum_value": [15]}) - datasource = self.get_datasource_mock() - test_viz = viz.TableViz(datasource, form_data) - data = test_viz.get_data(df) - self.assertEqual(["sum_value"], data["columns"]) - - class TestDistBarViz(SupersetTestCase): def test_groupby_nulls(self): form_data = { @@ -1311,7 +1044,7 @@ def test_apply_rolling(self): data={"y": [1.0, 2.0, 3.0, 4.0]}, ) self.assertEqual( - viz.BigNumberViz( + viz.NVD3TimeSeriesViz( datasource, { "metrics": ["y"], @@ -1325,7 +1058,7 @@ def test_apply_rolling(self): [1.0, 3.0, 6.0, 10.0], ) self.assertEqual( - viz.BigNumberViz( + viz.NVD3TimeSeriesViz( datasource, { "metrics": ["y"], @@ -1339,7 +1072,7 @@ def test_apply_rolling(self): [1.0, 3.0, 5.0, 7.0], ) self.assertEqual( - viz.BigNumberViz( + viz.NVD3TimeSeriesViz( datasource, { "metrics": ["y"], @@ -1361,7 +1094,7 @@ def test_apply_rolling_without_data(self): ), data={"y": [1.0, 2.0, 3.0, 4.0]}, ) - test_viz = viz.BigNumberViz( + test_viz = viz.NVD3TimeSeriesViz( datasource, { "metrics": ["y"], @@ -1374,34 +1107,6 @@ def test_apply_rolling_without_data(self): test_viz.apply_rolling(df) -class TestBigNumberViz(SupersetTestCase): - def test_get_data(self): - datasource = self.get_datasource_mock() - df = pd.DataFrame( - data={ - DTTM_ALIAS: pd.to_datetime( - ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] - ), - "y": [1.0, 2.0, 3.0, 4.0], - } - ) - data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) - self.assertEqual(data[2], {DTTM_ALIAS: pd.Timestamp("2019-01-05"), "y": 3}) - - def test_get_data_with_none(self): - datasource = self.get_datasource_mock() - df = pd.DataFrame( - data={ - DTTM_ALIAS: pd.to_datetime( - ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] - ), - "y": [1.0, 2.0, None, 4.0], - } - ) - data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) - assert np.isnan(data[2]["y"]) - - class TestFilterBoxViz(SupersetTestCase): def test_get_data(self): form_data = {