diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 56f53cc16a983..f6152b232a938 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -33,7 +33,10 @@
from superset.common.query_actions import get_query_results
from superset.common.utils import dataframe_utils
from superset.common.utils.query_cache_manager import QueryCacheManager
-from superset.common.utils.time_range_utils import get_since_until_from_query_object
+from superset.common.utils.time_range_utils import (
+ get_since_until_from_query_object,
+ get_since_until_from_time_range,
+)
from superset.connectors.base.models import BaseDatasource
from superset.constants import CacheRegion, TimeGrain
from superset.daos.annotation import AnnotationLayerDAO
@@ -64,6 +67,7 @@
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.utils.pandas_postprocessing.utils import unescape_separator
from superset.views.utils import get_viz
+from superset.viz import viz_types
if TYPE_CHECKING:
from superset.common.query_context import QueryContext
@@ -685,22 +689,53 @@ def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]:
def get_viz_annotation_data(
annotation_layer: dict[str, Any], force: bool
) -> dict[str, Any]:
- chart = ChartDAO.find_by_id(annotation_layer["value"])
- if not chart:
+ # pylint: disable=import-outside-toplevel,superfluous-parens
+ from superset.charts.data.commands.get_data_command import ChartDataCommand
+
+ if not (chart := ChartDAO.find_by_id(annotation_layer["value"])):
raise QueryObjectValidationError(_("The chart does not exist"))
- if not chart.datasource:
- raise QueryObjectValidationError(_("The chart datasource does not exist"))
- form_data = chart.form_data.copy()
- form_data.update(annotation_layer.get("overrides", {}))
+
try:
- viz_obj = get_viz(
- datasource_type=chart.datasource.type,
- datasource_id=chart.datasource.id,
- form_data=form_data,
- force=force,
- )
- payload = viz_obj.get_payload()
- return payload["data"]
+ if chart.viz_type in viz_types:
+ if not chart.datasource:
+ raise QueryObjectValidationError(
+ _("The chart datasource does not exist"),
+ )
+
+ form_data = chart.form_data.copy()
+ form_data.update(annotation_layer.get("overrides", {}))
+
+ payload = get_viz(
+ datasource_type=chart.datasource.type,
+ datasource_id=chart.datasource.id,
+ form_data=form_data,
+ force=force,
+ ).get_payload()
+
+ return payload["data"]
+
+ if not (query_context := chart.get_query_context()):
+ raise QueryObjectValidationError(
+ _("The chart query context does not exist"),
+ )
+
+ if overrides := annotation_layer.get("overrides"):
+ if time_grain_sqla := overrides.get("time_grain_sqla"):
+ for query_object in query_context.queries:
+ query_object.extras["time_grain_sqla"] = time_grain_sqla
+
+ if time_range := overrides.get("time_range"):
+ from_dttm, to_dttm = get_since_until_from_time_range(time_range)
+
+ for query_object in query_context.queries:
+ query_object.from_dttm = from_dttm
+ query_object.to_dttm = to_dttm
+
+ query_context.force = force
+ command = ChartDataCommand(query_context)
+ command.validate()
+ payload = command.run()
+ return {"records": payload["queries"][0]["data"]}
except SupersetException as ex:
raise QueryObjectValidationError(error_msg_from_exception(ex)) from ex
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index e91640b19029f..c9e38f1686cae 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -424,6 +424,16 @@ def create_slices(tbl: SqlaTable) -> tuple[list[Slice], list[Slice]]:
viz_type="table",
metrics=metrics,
),
+ query_context=get_slice_json(
+ default_query_context,
+ queries=[
+ {
+ "columns": ["ds"],
+ "metrics": metrics,
+ "time_range": "1983 : 2023",
+ }
+ ],
+ ),
),
Slice(
**slice_kwargs,
diff --git a/superset/viz.py b/superset/viz.py
index 4e39ae2a19f26..3051f104e20af 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -75,10 +75,8 @@
get_column_name,
get_column_names,
get_column_names_from_columns,
- get_metric_names,
JS_MAX_INTEGER,
merge_extra_filters,
- QueryMode,
simple_filter_to_adhoc,
)
from superset.utils.date_parser import get_since_until, parse_past_timedelta
@@ -701,158 +699,6 @@ def raise_for_access(self) -> None:
security_manager.raise_for_access(viz=self)
-class TableViz(BaseViz):
-
- """A basic html table that is sortable and searchable"""
-
- viz_type = "table"
- verbose_name = _("Table View")
- credits = 'a Superset original'
- is_timeseries = False
- enforce_numerical_metrics = False
-
- @deprecated(deprecated_in="3.0")
- def process_metrics(self) -> None:
- """Process form data and store parsed column configs.
- 1. Determine query mode based on form_data params.
- - Use `query_mode` if it has a valid value
- - Set as RAW mode if `all_columns` is set
- - Otherwise defaults to AGG mode
- 2. Determine output columns based on query mode.
- """
- # Verify form data first: if not specifying query mode, then cannot have both
- # GROUP BY and RAW COLUMNS.
- if (
- not self.form_data.get("query_mode")
- and self.form_data.get("all_columns")
- and (
- self.form_data.get("groupby")
- or self.form_data.get("metrics")
- or self.form_data.get("percent_metrics")
- )
- ):
- raise QueryObjectValidationError(
- _(
- "You cannot use [Columns] in combination with "
- "[Group By]/[Metrics]/[Percentage Metrics]. "
- "Please choose one or the other."
- )
- )
-
- super().process_metrics()
-
- self.query_mode: QueryMode = QueryMode.get(
- self.form_data.get("query_mode")
- ) or (
- # infer query mode from the presence of other fields
- QueryMode.RAW
- if len(self.form_data.get("all_columns") or []) > 0
- else QueryMode.AGGREGATE
- )
-
- columns: list[str] # output columns sans time and percent_metric column
- percent_columns: list[str] = [] # percent columns that needs extra computation
-
- if self.query_mode == QueryMode.RAW:
- columns = get_metric_names(self.form_data.get("all_columns"))
- else:
- columns = get_column_names(self.groupby) + get_metric_names(
- self.form_data.get("metrics")
- )
- percent_columns = get_metric_names(
- self.form_data.get("percent_metrics") or []
- )
-
- self.columns = columns
- self.percent_columns = percent_columns
- self.is_timeseries = self.should_be_timeseries()
-
- @deprecated(deprecated_in="3.0")
- def should_be_timeseries(self) -> bool:
- # TODO handle datasource-type-specific code in datasource
- conditions_met = self.form_data.get("granularity_sqla") and self.form_data.get(
- "time_grain_sqla"
- )
- if self.form_data.get("include_time") and not conditions_met:
- raise QueryObjectValidationError(
- _("Pick a granularity in the Time section or " "uncheck 'Include Time'")
- )
- return bool(self.form_data.get("include_time"))
-
- @deprecated(deprecated_in="3.0")
- def query_obj(self) -> QueryObjectDict:
- query_obj = super().query_obj()
- if self.query_mode == QueryMode.RAW:
- query_obj["columns"] = self.form_data.get("all_columns")
- order_by_cols = self.form_data.get("order_by_cols") or []
- query_obj["orderby"] = [json.loads(t) for t in order_by_cols]
- # must disable groupby and metrics in raw mode
- query_obj["groupby"] = []
- query_obj["metrics"] = []
- # raw mode does not support timeseries queries
- query_obj["timeseries_limit_metric"] = None
- query_obj["timeseries_limit"] = None
- query_obj["is_timeseries"] = None
- else:
- sort_by = self.form_data.get("timeseries_limit_metric")
- if sort_by:
- sort_by_label = utils.get_metric_name(sort_by)
- if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
- query_obj["metrics"].append(sort_by)
- query_obj["orderby"] = [
- (sort_by, not self.form_data.get("order_desc", True))
- ]
- elif query_obj["metrics"]:
- # Legacy behavior of sorting by first metric by default
- first_metric = query_obj["metrics"][0]
- query_obj["orderby"] = [
- (first_metric, not self.form_data.get("order_desc", True))
- ]
- return query_obj
-
- @deprecated(deprecated_in="3.0")
- def get_data(self, df: pd.DataFrame) -> VizData:
- """
- Transform the query result to the table representation.
-
- :param df: The interim dataframe
- :returns: The table visualization data
-
- The interim dataframe comprises of the group-by and non-group-by columns and
- the union of the metrics representing the non-percent and percent metrics. Note
- the percent metrics have yet to be transformed.
- """
- # Transform the data frame to adhere to the UI ordering of the columns and
- # metrics whilst simultaneously computing the percentages (via normalization)
- # for the percent metrics.
- if df.empty:
- return None
-
- columns, percent_columns = self.columns, self.percent_columns
- if DTTM_ALIAS in df and self.is_timeseries:
- columns = [DTTM_ALIAS] + columns
- df = pd.concat(
- [
- df[columns],
- (df[percent_columns].div(df[percent_columns].sum()).add_prefix("%")),
- ],
- axis=1,
- )
- return self.handle_js_int_overflow(
- dict(records=df.to_dict(orient="records"), columns=list(df.columns))
- )
-
- @staticmethod
- @deprecated(deprecated_in="3.0")
- def json_dumps(query_obj: Any, sort_keys: bool = False) -> str:
- return json.dumps(
- query_obj,
- default=utils.json_iso_dttm_ser,
- sort_keys=sort_keys,
- ignore_nan=True,
- )
-
-
class TimeTableViz(BaseViz):
"""A data table with rich time-series related columns"""
@@ -1076,65 +922,6 @@ def get_data(self, df: pd.DataFrame) -> VizData:
}
-class BigNumberViz(BaseViz):
-
- """Put emphasis on a single metric with this big number viz"""
-
- viz_type = "big_number"
- verbose_name = _("Big Number with Trendline")
- credits = 'a Superset original'
- is_timeseries = True
-
- @deprecated(deprecated_in="3.0")
- def query_obj(self) -> QueryObjectDict:
- query_obj = super().query_obj()
- metric = self.form_data.get("metric")
- if not metric:
- raise QueryObjectValidationError(_("Pick a metric!"))
- query_obj["metrics"] = [self.form_data.get("metric")]
- self.form_data["metric"] = metric
- return query_obj
-
- @deprecated(deprecated_in="3.0")
- def get_data(self, df: pd.DataFrame) -> VizData:
- if df.empty:
- return None
-
- df = df.pivot_table(
- index=DTTM_ALIAS,
- columns=[],
- values=self.metric_labels,
- dropna=False,
- aggfunc=np.min, # looking for any (only) value, preserving `None`
- )
- df = self.apply_rolling(df)
- df[DTTM_ALIAS] = df.index
- return super().get_data(df)
-
-
-class BigNumberTotalViz(BaseViz):
-
- """Put emphasis on a single metric with this big number viz"""
-
- viz_type = "big_number_total"
- verbose_name = _("Big Number")
- credits = 'a Superset original'
- is_timeseries = False
-
- @deprecated(deprecated_in="3.0")
- def query_obj(self) -> QueryObjectDict:
- query_obj = super().query_obj()
- metric = self.form_data.get("metric")
- if not metric:
- raise QueryObjectValidationError(_("Pick a metric!"))
- query_obj["metrics"] = [self.form_data.get("metric")]
- self.form_data["metric"] = metric
-
- # Limiting rows is not required as only one cell is returned
- query_obj["row_limit"] = None
- return query_obj
-
-
class NVD3TimeSeriesViz(NVD3Viz):
"""A rich line chart component with tons of options"""
diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py
index a7da8a50d2a59..b2a8704dfb237 100644
--- a/tests/integration_tests/cache_tests.py
+++ b/tests/integration_tests/cache_tests.py
@@ -46,7 +46,7 @@ def test_no_data_cache(self):
app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "NullCache"}
cache_manager.init_app(app)
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
json_endpoint = "/superset/explore_json/{}/{}/".format(
slc.datasource_type, slc.datasource_id
)
@@ -73,7 +73,7 @@ def test_slice_data_cache(self):
}
cache_manager.init_app(app)
- slc = self.get_slice("Boys", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
json_endpoint = "/superset/explore_json/{}/{}/".format(
slc.datasource_type, slc.datasource_id
)
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index f09662a8de919..734990a1eecc3 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -1715,7 +1715,7 @@ def test_gets_owned_created_favorited_by_me_filter(self):
)
def test_warm_up_cache(self):
self.login()
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py
index 911f3bf5daa4c..9580c2bf33e6b 100644
--- a/tests/integration_tests/charts/commands_tests.py
+++ b/tests/integration_tests/charts/commands_tests.py
@@ -456,7 +456,7 @@ def test_warm_up_cache_command_chart_not_found(self):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache(self):
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
result = ChartWarmUpCacheCommand(slc.id, None, None).run()
self.assertEqual(
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index c96555503b598..0acb19969a23d 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -173,7 +173,7 @@ def test_slice_endpoint(self):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_viz_cache_key(self):
self.login(username="admin")
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
viz = slc.viz
qobj = viz.query_obj()
@@ -279,7 +279,9 @@ def test_slice_data(self):
# slice data should have some required attributes
self.login(username="admin")
slc = self.get_slice(
- slice_name="Girls", session=db.session, expunge_from_session=False
+ slice_name="Top 10 Girl Name Share",
+ session=db.session,
+ expunge_from_session=False,
)
slc_data_attributes = slc.data.keys()
assert "changed_on" in slc_data_attributes
@@ -391,7 +393,7 @@ def test_databaseview_edit(self, username="admin"):
)
def test_warm_up_cache(self):
self.login()
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
self.assertEqual(
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
@@ -418,10 +420,10 @@ def test_cache_logging(self):
self.login("admin")
store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
- girls_slice = self.get_slice("Girls", db.session)
- self.get_json_resp(f"/superset/warm_up_cache?slice_id={girls_slice.id}")
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
- assert ck.datasource_uid == f"{girls_slice.table.id}__table"
+ assert ck.datasource_uid == f"{slc.table.id}__table"
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys
def test_redirect_invalid(self):
diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py
index fe199443875dd..1518a69f9dbbe 100644
--- a/tests/integration_tests/security_tests.py
+++ b/tests/integration_tests/security_tests.py
@@ -1680,7 +1680,7 @@ def test_raise_for_access_table(self, mock_can_access):
def test_raise_for_access_viz(
self, mock_can_access_schema, mock_can_access, mock_is_owner
):
- test_viz = viz.TableViz(self.get_datasource_mock(), form_data={})
+ test_viz = viz.TimeTableViz(self.get_datasource_mock(), form_data={})
mock_can_access_schema.return_value = True
security_manager.raise_for_access(viz=test_viz)
diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py
index 2986188ff98ec..c0383d1d0b75d 100644
--- a/tests/integration_tests/utils_tests.py
+++ b/tests/integration_tests/utils_tests.py
@@ -974,7 +974,7 @@ def test_get_form_data_corrupted_json(self) -> None:
def test_log_this(self) -> None:
# TODO: Add additional scenarios.
self.login(username="admin")
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share", db.session)
dashboard_id = 1
assert slc.viz is not None
diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py
index ac390b3976a04..f1665e96888d0 100644
--- a/tests/integration_tests/viz_tests.py
+++ b/tests/integration_tests/viz_tests.py
@@ -45,7 +45,7 @@ def test_constructor_exception_no_datasource(self):
viz.BaseViz(datasource, form_data)
def test_process_metrics(self):
- # test TableViz metrics in correct order
+ # test TimeTableViz metrics in correct order
form_data = {
"url_params": {},
"row_limit": 500,
@@ -55,7 +55,7 @@ def test_process_metrics(self):
"granularity_sqla": "year",
"page_length": 0,
"all_columns": [],
- "viz_type": "table",
+ "viz_type": "time_table",
"since": "2014-01-01",
"until": "2014-01-02",
"metrics": ["sum__SP_POP_TOTL", "SUM(SE_PRM_NENR_MA)", "SUM(SP_URB_TOTL)"],
@@ -177,273 +177,6 @@ def test_cache_timeout(self):
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout
-class TestTableViz(SupersetTestCase):
- def test_get_data_applies_percentage(self):
- form_data = {
- "groupby": ["groupA", "groupB"],
- "metrics": [
- {
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "label": "SUM(value1)",
- "column": {"column_name": "value1", "type": "DOUBLE"},
- },
- "count",
- "avg__C",
- ],
- "percent_metrics": [
- {
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "label": "SUM(value1)",
- "column": {"column_name": "value1", "type": "DOUBLE"},
- },
- "avg__B",
- ],
- }
- datasource = self.get_datasource_mock()
-
- df = pd.DataFrame(
- {
- "SUM(value1)": [15, 20, 25, 40],
- "avg__B": [10, 20, 5, 15],
- "avg__C": [11, 22, 33, 44],
- "count": [6, 7, 8, 9],
- "groupA": ["A", "B", "C", "C"],
- "groupB": ["x", "x", "y", "z"],
- }
- )
-
- test_viz = viz.TableViz(datasource, form_data)
- data = test_viz.get_data(df)
- # Check method correctly transforms data and computes percents
- self.assertEqual(
- [
- "groupA",
- "groupB",
- "SUM(value1)",
- "count",
- "avg__C",
- "%SUM(value1)",
- "%avg__B",
- ],
- list(data["columns"]),
- )
- expected = [
- {
- "groupA": "A",
- "groupB": "x",
- "SUM(value1)": 15,
- "count": 6,
- "avg__C": 11,
- "%SUM(value1)": 0.15,
- "%avg__B": 0.2,
- },
- {
- "groupA": "B",
- "groupB": "x",
- "SUM(value1)": 20,
- "count": 7,
- "avg__C": 22,
- "%SUM(value1)": 0.2,
- "%avg__B": 0.4,
- },
- {
- "groupA": "C",
- "groupB": "y",
- "SUM(value1)": 25,
- "count": 8,
- "avg__C": 33,
- "%SUM(value1)": 0.25,
- "%avg__B": 0.1,
- },
- {
- "groupA": "C",
- "groupB": "z",
- "SUM(value1)": 40,
- "count": 9,
- "avg__C": 44,
- "%SUM(value1)": 0.4,
- "%avg__B": 0.3,
- },
- ]
- self.assertEqual(expected, data["records"])
-
- def test_parse_adhoc_filters(self):
- form_data = {
- "metrics": [
- {
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "label": "SUM(value1)",
- "column": {"column_name": "value1", "type": "DOUBLE"},
- }
- ],
- "adhoc_filters": [
- {
- "expressionType": "SIMPLE",
- "clause": "WHERE",
- "subject": "value2",
- "operator": ">",
- "comparator": "100",
- },
- {
- "expressionType": "SQL",
- "clause": "HAVING",
- "sqlExpression": "SUM(value1) > 5",
- },
- {
- "expressionType": "SQL",
- "clause": "WHERE",
- "sqlExpression": "value3 in ('North America')",
- },
- ],
- }
- datasource = self.get_datasource_mock()
- test_viz = viz.TableViz(datasource, form_data)
- query_obj = test_viz.query_obj()
- self.assertEqual(
- [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"]
- )
- self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"])
- self.assertEqual("(SUM(value1) > 5)", query_obj["extras"]["having"])
-
- def test_adhoc_filters_overwrite_legacy_filters(self):
- form_data = {
- "metrics": [
- {
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "label": "SUM(value1)",
- "column": {"column_name": "value1", "type": "DOUBLE"},
- }
- ],
- "adhoc_filters": [
- {
- "expressionType": "SIMPLE",
- "clause": "WHERE",
- "subject": "value2",
- "operator": ">",
- "comparator": "100",
- },
- {
- "expressionType": "SQL",
- "clause": "WHERE",
- "sqlExpression": "value3 in ('North America')",
- },
- ],
- "having": "SUM(value1) > 5",
- }
- datasource = self.get_datasource_mock()
- test_viz = viz.TableViz(datasource, form_data)
- query_obj = test_viz.query_obj()
- self.assertEqual(
- [{"col": "value2", "val": "100", "op": ">"}], query_obj["filter"]
- )
- self.assertEqual("(value3 in ('North America'))", query_obj["extras"]["where"])
- self.assertEqual("", query_obj["extras"]["having"])
-
- def test_query_obj_merges_percent_metrics(self):
- datasource = self.get_datasource_mock()
- form_data = {
- "metrics": ["sum__A", "count", "avg__C"],
- "percent_metrics": ["sum__A", "avg__B", "max__Y"],
- }
- test_viz = viz.TableViz(datasource, form_data)
- query_obj = test_viz.query_obj()
- self.assertEqual(
- ["sum__A", "count", "avg__C", "avg__B", "max__Y"], query_obj["metrics"]
- )
-
- def test_query_obj_throws_columns_and_metrics(self):
- datasource = self.get_datasource_mock()
- form_data = {"all_columns": ["A", "B"], "metrics": ["x", "y"]}
- with self.assertRaises(Exception):
- test_viz = viz.TableViz(datasource, form_data)
- test_viz.query_obj()
- del form_data["metrics"]
- form_data["groupby"] = ["B", "C"]
- with self.assertRaises(Exception):
- test_viz = viz.TableViz(datasource, form_data)
- test_viz.query_obj()
-
- @patch("superset.viz.BaseViz.query_obj")
- def test_query_obj_merges_all_columns(self, super_query_obj):
- datasource = self.get_datasource_mock()
- form_data = {
- "all_columns": ["colA", "colB", "colC"],
- "order_by_cols": ['["colA", "colB"]', '["colC"]'],
- }
- super_query_obj.return_value = {
- "columns": ["colD", "colC"],
- "groupby": ["colA", "colB"],
- }
- test_viz = viz.TableViz(datasource, form_data)
- query_obj = test_viz.query_obj()
- self.assertEqual(form_data["all_columns"], query_obj["columns"])
- self.assertEqual([], query_obj["groupby"])
- self.assertEqual([["colA", "colB"], ["colC"]], query_obj["orderby"])
-
- def test_query_obj_uses_sortby(self):
- datasource = self.get_datasource_mock()
- form_data = {
- "metrics": ["colA", "colB"],
- "order_desc": False,
- }
-
- def run_test(metric):
- form_data["timeseries_limit_metric"] = metric
- test_viz = viz.TableViz(datasource, form_data)
- query_obj = test_viz.query_obj()
- self.assertEqual(["colA", "colB", metric], query_obj["metrics"])
- self.assertEqual([(metric, True)], query_obj["orderby"])
-
- run_test("simple_metric")
- run_test(
- {
- "label": "adhoc_metric",
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "column": {
- "column_name": "sort_column",
- },
- }
- )
-
- def test_should_be_timeseries_raises_when_no_granularity(self):
- datasource = self.get_datasource_mock()
- form_data = {"include_time": True}
- with self.assertRaises(Exception):
- test_viz = viz.TableViz(datasource, form_data)
- test_viz.should_be_timeseries()
-
- def test_adhoc_metric_with_sortby(self):
- metrics = [
- {
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "label": "sum_value",
- "column": {"column_name": "value1", "type": "DOUBLE"},
- }
- ]
- form_data = {
- "metrics": metrics,
- "timeseries_limit_metric": {
- "expressionType": "SIMPLE",
- "aggregate": "SUM",
- "label": "SUM(value1)",
- "column": {"column_name": "value1", "type": "DOUBLE"},
- },
- "order_desc": False,
- }
-
- df = pd.DataFrame({"SUM(value1)": [15], "sum_value": [15]})
- datasource = self.get_datasource_mock()
- test_viz = viz.TableViz(datasource, form_data)
- data = test_viz.get_data(df)
- self.assertEqual(["sum_value"], data["columns"])
-
-
class TestDistBarViz(SupersetTestCase):
def test_groupby_nulls(self):
form_data = {
@@ -1311,7 +1044,7 @@ def test_apply_rolling(self):
data={"y": [1.0, 2.0, 3.0, 4.0]},
)
self.assertEqual(
- viz.BigNumberViz(
+ viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
@@ -1325,7 +1058,7 @@ def test_apply_rolling(self):
[1.0, 3.0, 6.0, 10.0],
)
self.assertEqual(
- viz.BigNumberViz(
+ viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
@@ -1339,7 +1072,7 @@ def test_apply_rolling(self):
[1.0, 3.0, 5.0, 7.0],
)
self.assertEqual(
- viz.BigNumberViz(
+ viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
@@ -1361,7 +1094,7 @@ def test_apply_rolling_without_data(self):
),
data={"y": [1.0, 2.0, 3.0, 4.0]},
)
- test_viz = viz.BigNumberViz(
+ test_viz = viz.NVD3TimeSeriesViz(
datasource,
{
"metrics": ["y"],
@@ -1374,34 +1107,6 @@ def test_apply_rolling_without_data(self):
test_viz.apply_rolling(df)
-class TestBigNumberViz(SupersetTestCase):
- def test_get_data(self):
- datasource = self.get_datasource_mock()
- df = pd.DataFrame(
- data={
- DTTM_ALIAS: pd.to_datetime(
- ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]
- ),
- "y": [1.0, 2.0, 3.0, 4.0],
- }
- )
- data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
- self.assertEqual(data[2], {DTTM_ALIAS: pd.Timestamp("2019-01-05"), "y": 3})
-
- def test_get_data_with_none(self):
- datasource = self.get_datasource_mock()
- df = pd.DataFrame(
- data={
- DTTM_ALIAS: pd.to_datetime(
- ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]
- ),
- "y": [1.0, 2.0, None, 4.0],
- }
- )
- data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
- assert np.isnan(data[2]["y"])
-
-
class TestFilterBoxViz(SupersetTestCase):
def test_get_data(self):
form_data = {