Skip to content

Commit

Permalink
fix(explore): unable to update linked charts (#22896)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpark authored Feb 2, 2023
1 parent deb5109 commit ad1ffbd
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
5 changes: 4 additions & 1 deletion superset/charts/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def validate(self) -> None:

# Validate/Populate dashboards only if it's a list
if dashboard_ids is not None:
dashboards = DashboardDAO.find_by_ids(dashboard_ids)
dashboards = DashboardDAO.find_by_ids(
dashboard_ids,
skip_base_filter=True,
)
if len(dashboards) != len(dashboard_ids):
exceptions.append(DashboardsNotFoundValidationError())
self._properties["dashboards"] = dashboards
Expand Down
14 changes: 10 additions & 4 deletions superset/dao/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,22 @@ def find_by_id(
return None

@classmethod
def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[Model]:
def find_by_ids(
cls,
model_ids: Union[List[str], List[int]],
session: Session = None,
skip_base_filter: bool = False,
) -> List[Model]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
id_col = getattr(cls.model_cls, cls.id_column_name, None)
if id_col is None:
return []
query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
if cls.base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
session = session or db.session
query = session.query(cls.model_cls).filter(id_col.in_(model_ids))
if cls.base_filter and not skip_base_filter:
data_model = SQLAInterface(cls.model_cls, session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
Expand Down
55 changes: 55 additions & 0 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,61 @@ def test_update_chart_not_owned(self):
db.session.delete(user_alpha2)
db.session.commit()

def test_update_chart_linked_with_not_owned_dashboard(self):
"""
Chart API: Test update chart which is linked to not owned dashboard
"""
user_alpha1 = self.create_user(
"alpha1", "password", "Alpha", email="[email protected]"
)
user_alpha2 = self.create_user(
"alpha2", "password", "Alpha", email="[email protected]"
)
chart = self.insert_chart("title", [user_alpha1.id], 1)

original_dashboard = Dashboard()
original_dashboard.dashboard_title = "Original Dashboard"
original_dashboard.slug = "slug"
original_dashboard.owners = [user_alpha1]
original_dashboard.slices = [chart]
original_dashboard.published = False
db.session.add(original_dashboard)

new_dashboard = Dashboard()
new_dashboard.dashboard_title = "Cloned Dashboard"
new_dashboard.slug = "new_slug"
new_dashboard.owners = [user_alpha2]
new_dashboard.slices = [chart]
new_dashboard.published = False
db.session.add(new_dashboard)

self.login(username="alpha1", password="password")
chart_data_with_invalid_dashboard = {
"slice_name": "title1_changed",
"dashboards": [original_dashboard.id, 0],
}
chart_data = {
"slice_name": "title1_changed",
"dashboards": [original_dashboard.id, new_dashboard.id],
}
uri = f"api/v1/chart/{chart.id}"

rv = self.put_assert_metric(uri, chart_data_with_invalid_dashboard, "put")
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"dashboards": ["Dashboards do not exist"]}}
self.assertEqual(response, expected_response)

rv = self.put_assert_metric(uri, chart_data, "put")
self.assertEqual(rv.status_code, 200)

db.session.delete(chart)
db.session.delete(original_dashboard)
db.session.delete(new_dashboard)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
db.session.commit()

def test_update_chart_validate_datasource(self):
"""
Chart API: Test update validate datasource
Expand Down
30 changes: 30 additions & 0 deletions tests/unit_tests/datasets/dao/dao_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,33 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
skip_base_filter=True,
)
assert result is None


def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.dao import DatasetDAO

result = DatasetDAO.find_by_ids(
[1, 125326326],
session=session_with_data,
skip_base_filter=True,
)

assert result
assert [1] == list(map(lambda x: x.id, result))
assert ["my_sqla_table"] == list(map(lambda x: x.table_name, result))
assert isinstance(result[0], SqlaTable)


def test_datasource_find_by_ids_skip_base_filter_not_found(
session_with_data: Session,
) -> None:
from superset.datasets.dao import DatasetDAO

result = DatasetDAO.find_by_ids(
[125326326, 125326326125326326],
session=session_with_data,
skip_base_filter=True,
)

assert len(result) == 0

0 comments on commit ad1ffbd

Please sign in to comment.