Skip to content

Commit

Permalink
add tests + refactor fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Dec 3, 2020
1 parent 739bf15 commit 70114ce
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 78 deletions.
1 change: 1 addition & 0 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def data(self) -> Response:
"""
if request.is_json:
json_body = request.json
print(json_body)
elif request.form.get("form_data"):
# CSV export submits regular form data
json_body = json.loads(request.form["form_data"])
Expand Down
5 changes: 2 additions & 3 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from superset.common.query_context import QueryContext
from superset.utils import schema as utils
from superset.utils.core import (
AnnotationType,
FilterOperator,
PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation,
Expand Down Expand Up @@ -783,9 +784,7 @@ class ChartDataExtrasSchema(Schema):
class AnnotationLayerSchema(Schema):
annotationType = fields.String(
description="Type of annotation layer",
validate=validate.OneOf(
choices=("EVENT", "FORMULA", "INTERVAL", "TIME_SERIES",)
),
validate=validate.OneOf(choices=[ann.value for ann in AnnotationType]),
)
color = fields.String(description="Layer color", allow_none=True,)
descriptionColumns = fields.List(
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
)
return annotation_data

def get_df_payload( # pylint: disable=too-many-statements
def get_df_payload( # pylint: disable=too-many-statements,too-many-locals
self, query_obj: QueryObject, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
Expand Down
7 changes: 7 additions & 0 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,13 @@ class ExtraFiltersTimeColumnType(str, Enum):
TIME_RANGE = "__time_range"


class AnnotationType(str, Enum):
FORMULA = "FORMULA"
INTERVAL = "INTERVAL"
EVENT = "EVENT"
TIME_SERIES = "TIME_SERIES"


def is_test() -> bool:
return strtobool(os.environ.get("SUPERSET_TESTENV", "false"))

Expand Down
79 changes: 9 additions & 70 deletions tests/annotation_layers/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
from datetime import datetime
from typing import Optional
import json

import pytest
Expand All @@ -29,77 +27,17 @@
from superset.models.annotations import Annotation, AnnotationLayer

from tests.base_tests import SupersetTestCase

from tests.annotation_layers.fixtures import (
create_annotation_layers,
get_end_dttm,
get_start_dttm,
)

ANNOTATION_LAYERS_COUNT = 10
ANNOTATIONS_COUNT = 5


class TestAnnotationLayerApi(SupersetTestCase):
def insert_annotation_layer(
self, name: str = "", descr: str = ""
) -> AnnotationLayer:
annotation_layer = AnnotationLayer(name=name, descr=descr,)
db.session.add(annotation_layer)
db.session.commit()
return annotation_layer

def insert_annotation(
self,
layer: AnnotationLayer,
short_descr: str,
long_descr: str,
json_metadata: Optional[str] = "",
start_dttm: Optional[datetime] = None,
end_dttm: Optional[datetime] = None,
) -> Annotation:
annotation = Annotation(
layer=layer,
short_descr=short_descr,
long_descr=long_descr,
json_metadata=json_metadata,
start_dttm=start_dttm,
end_dttm=end_dttm,
)
db.session.add(annotation)
db.session.commit()
return annotation

@pytest.fixture()
def create_annotation_layers(self):
"""
Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations
and a final one with ANNOTATION_COUNT childs
:return:
"""
with self.create_app().app_context():
annotation_layers = []
annotations = []
for cx in range(ANNOTATION_LAYERS_COUNT - 1):
annotation_layers.append(
self.insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}")
)
layer_with_annotations = self.insert_annotation_layer(
"layer_with_annotations"
)
annotation_layers.append(layer_with_annotations)
for cx in range(ANNOTATIONS_COUNT):
annotations.append(
self.insert_annotation(
layer_with_annotations,
short_descr=f"short_descr{cx}",
long_descr=f"long_descr{cx}",
)
)
yield annotation_layers

# rollback changes
for annotation_layer in annotation_layers:
db.session.delete(annotation_layer)
for annotation in annotations:
db.session.delete(annotation)
db.session.commit()

@staticmethod
def get_layer_with_annotation() -> AnnotationLayer:
return (
Expand Down Expand Up @@ -421,9 +359,10 @@ def test_get_annotation(self):
"""
Annotation API: Test get annotation
"""
annotation_id = 1
annotation = (
db.session.query(Annotation)
.filter(Annotation.short_descr == "short_descr1")
.filter(Annotation.short_descr == f"short_descr{annotation_id}")
.one_or_none()
)

Expand All @@ -436,12 +375,12 @@ def test_get_annotation(self):

expected_result = {
"id": annotation.id,
"end_dttm": None,
"end_dttm": get_end_dttm(annotation_id).isoformat(),
"json_metadata": "",
"layer": {"id": annotation.layer_id, "name": "layer_with_annotations"},
"long_descr": annotation.long_descr,
"short_descr": annotation.short_descr,
"start_dttm": None,
"start_dttm": get_start_dttm(annotation_id).isoformat(),
}

data = json.loads(rv.data.decode("utf-8"))
Expand Down
101 changes: 101 additions & 0 deletions tests/annotation_layers/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import pytest
from datetime import datetime
from typing import Optional

from superset import db
from superset.models.annotations import Annotation, AnnotationLayer

from tests.test_app import app


ANNOTATION_LAYERS_COUNT = 10
ANNOTATIONS_COUNT = 5


def get_start_dttm(annotation_id: int) -> datetime:
return datetime(1990 + annotation_id, 1, 1)


def get_end_dttm(annotation_id: int) -> datetime:
return datetime(1990 + annotation_id, 7, 1)


def _insert_annotation_layer(name: str = "", descr: str = "") -> AnnotationLayer:
annotation_layer = AnnotationLayer(name=name, descr=descr,)
db.session.add(annotation_layer)
db.session.commit()
return annotation_layer


def _insert_annotation(
layer: AnnotationLayer,
short_descr: str,
long_descr: str,
json_metadata: Optional[str] = "",
start_dttm: Optional[datetime] = None,
end_dttm: Optional[datetime] = None,
) -> Annotation:
annotation = Annotation(
layer=layer,
short_descr=short_descr,
long_descr=long_descr,
json_metadata=json_metadata,
start_dttm=start_dttm,
end_dttm=end_dttm,
)
db.session.add(annotation)
db.session.commit()
return annotation


@pytest.fixture()
def create_annotation_layers():
"""
Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations
and a final one with ANNOTATION_COUNT childs
:return:
"""
with app.app_context():
annotation_layers = []
annotations = []
for cx in range(ANNOTATION_LAYERS_COUNT - 1):
annotation_layers.append(
_insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}")
)
layer_with_annotations = _insert_annotation_layer("layer_with_annotations")
annotation_layers.append(layer_with_annotations)
for cx in range(ANNOTATIONS_COUNT):
annotations.append(
_insert_annotation(
layer_with_annotations,
short_descr=f"short_descr{cx}",
long_descr=f"long_descr{cx}",
start_dttm=get_start_dttm(cx),
end_dttm=get_end_dttm(cx),
)
)
yield annotation_layers

# rollback changes
for annotation_layer in annotation_layers:
db.session.delete(annotation_layer)
for annotation in annotations:
db.session.delete(annotation)
db.session.commit()
52 changes: 48 additions & 4 deletions tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@
from sqlalchemy import and_
from sqlalchemy.sql import func

from superset.connectors.sqla.models import SqlaTable
from superset.utils.core import get_example_database
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
from tests.test_app import app
from superset.connectors.sqla.models import SqlaTable
from superset.utils.core import AnnotationType, get_example_database
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.models.annotations import AnnotationLayer
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils import core as utils

from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
Expand All @@ -50,7 +51,9 @@
dataset_config,
dataset_metadata_config,
)
from tests.fixtures.query_context import get_query_context
from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
from tests.annotation_layers.fixtures import create_annotation_layers

CHART_DATA_URI = "api/v1/chart/data"
CHARTS_FIXTURE_COUNT = 10
Expand Down Expand Up @@ -1383,3 +1386,44 @@ def test_import_chart_invalid(self):
assert response == {
"message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}}
}

@pytest.mark.usefixtures("create_annotation_layers")
def test_chart_data_annotations(self):
"""
Chart data API: Test chart data query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)

annotation_layers = []
request_payload["queries"][0]["annotation_layers"] = annotation_layers

# formula
annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA])

# interval
interval_layer = (
db.session.query(AnnotationLayer)
.filter(AnnotationLayer.name == "name1")
.one()
)
interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL]
interval["value"] = interval_layer.id
annotation_layers.append(interval)

# event
event_layer = (
db.session.query(AnnotationLayer)
.filter(AnnotationLayer.name == "name2")
.one()
)
event = ANNOTATION_LAYERS[AnnotationType.EVENT]
event["value"] = event_layer.id
annotation_layers.append(event)

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
# response should only contain interval and event data, not formula
self.assertEqual(len(data["result"][0]["annotation_data"]), 2)
Loading

0 comments on commit 70114ce

Please sign in to comment.