diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 3b404743317a0..ff90cb2a56fef 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -2482,6 +2482,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): __tablename__ = "row_level_security_filters" id = Column(Integer, primary_key=True) + name = Column(String(255), unique=True, nullable=False) + description = Column(Text) filter_type = Column( Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType]) ) @@ -2494,5 +2496,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): tables = relationship( SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters" ) - clause = Column(Text, nullable=False) diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index ef5afa5d05b28..e3a3725f311b4 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -26,7 +26,7 @@ from flask_appbuilder.security.decorators import has_access from flask_babel import lazy_gettext as _ from wtforms.ext.sqlalchemy.fields import QuerySelectField -from wtforms.validators import Regexp +from wtforms.validators import DataRequired, Regexp from superset import app, db from superset.connectors.base.views import DatasourceModelView @@ -47,6 +47,19 @@ logger = logging.getLogger(__name__) +class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-methods + """ + Select required flag on the input field will not work well on Chrome + Console error: + An invalid form control with name='tables' is not focusable. + + This makes a simple override to the DataRequired to be used specifically with + select fields + """ + + field_flags = () + + class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): datamodel = SQLAInterface(models.TableColumn) # TODO TODO, review need for this on related_views @@ -272,21 +285,39 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin): edit_title = _("Edit Row level security filter") list_columns = [ + "name", "filter_type", "tables", "roles", - "group_key", "clause", "creator", "modified", ] - order_columns = ["filter_type", "group_key", "clause", "modified"] - edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"] + order_columns = ["name", "filter_type", "clause", "modified"] + edit_columns = [ + "name", + "description", + "filter_type", + "tables", + "roles", + "group_key", + "clause", + ] show_columns = edit_columns - search_columns = ("filter_type", "tables", "roles", "group_key", "clause") + search_columns = ( + "name", + "description", + "filter_type", + "tables", + "roles", + "group_key", + "clause", + ) add_columns = edit_columns base_order = ("changed_on", "desc") description_columns = { + "name": _("Choose a unique name"), + "description": _("Optionally add a detailed description"), "filter_type": _( "Regular filters add where clauses to queries if a user belongs to a " "role referenced in the filter. Base filters apply filters to all queries " @@ -319,12 +350,16 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin): ), } label_columns = { + "name": _("Name"), + "description": _("Description"), "tables": _("Tables"), "roles": _("Roles"), "clause": _("Clause"), "creator": _("Creator"), "modified": _("Modified"), } + validators_columns = {"tables": [SelectDataRequired()]} + if app.config["RLS_FORM_QUERY_REL_FIELDS"]: add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"] edit_form_query_rel_fields = add_form_query_rel_fields diff --git a/superset/migrations/versions/2022-06-19_16-17_f3afaf1f11f0_add_unique_name_desc_rls.py b/superset/migrations/versions/2022-06-19_16-17_f3afaf1f11f0_add_unique_name_desc_rls.py new file mode 100644 index 0000000000000..0d8b3334a4fcf --- /dev/null +++ b/superset/migrations/versions/2022-06-19_16-17_f3afaf1f11f0_add_unique_name_desc_rls.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add_unique_name_desc_rls + +Revision ID: f3afaf1f11f0 +Revises: e786798587de +Create Date: 2022-06-19 16:17:23.318618 + +""" + +# revision identifiers, used by Alembic. +revision = "f3afaf1f11f0" +down_revision = "e786798587de" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session + +Base = declarative_base() + + +class RowLevelSecurityFilter(Base): + __tablename__ = "row_level_security_filters" + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(255), unique=True, nullable=False) + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + bind = op.get_bind() + session = Session(bind=bind) + + op.add_column( + "row_level_security_filters", sa.Column("name", sa.String(length=255)) + ) + op.add_column( + "row_level_security_filters", sa.Column("description", sa.Text(), nullable=True) + ) + + # Set initial default names make sure we can have unique non null values + all_rls = session.query(RowLevelSecurityFilter).all() + for rls in all_rls: + rls.name = f"rls-{rls.id}" + session.commit() + + # Now it's safe so set non-null and unique + # add unique constraint + with op.batch_alter_table("row_level_security_filters") as batch_op: + # batch mode is required for sqlite + batch_op.alter_column( + "name", + existing_type=sa.String(255), + nullable=False, + ) + batch_op.create_unique_constraint("uq_rls_name", ["name"]) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("uq_rls_name", "row_level_security_filters", type_="unique") + op.drop_column("row_level_security_filters", "description") + op.drop_column("row_level_security_filters", "name") + # ### end Alembic commands ### diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 1e46bfb996c5b..ebd95cae39bd7 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -25,7 +25,6 @@ from superset import db, security_manager from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.security.guest_token import ( - GuestTokenRlsRule, GuestTokenResourceType, GuestUser, ) @@ -82,6 +81,7 @@ def setUp(self): # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) self.rls_entry1 = RowLevelSecurityFilter() + self.rls_entry1.name = "rls_entry1" self.rls_entry1.tables.extend( session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) @@ -96,6 +96,7 @@ def setUp(self): # Create regular RowLevelSecurityFilter (birth_names name starts with A or B) self.rls_entry2 = RowLevelSecurityFilter() + self.rls_entry2.name = "rls_entry2" self.rls_entry2.tables.extend( session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["birth_names"])) @@ -109,6 +110,7 @@ def setUp(self): # Create Regular RowLevelSecurityFilter (birth_names name starts with Q) self.rls_entry3 = RowLevelSecurityFilter() + self.rls_entry3.name = "rls_entry3" self.rls_entry3.tables.extend( session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["birth_names"])) @@ -122,6 +124,7 @@ def setUp(self): # Create Base RowLevelSecurityFilter (birth_names boys) self.rls_entry4 = RowLevelSecurityFilter() + self.rls_entry4.name = "rls_entry4" self.rls_entry4.tables.extend( session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["birth_names"])) @@ -146,6 +149,94 @@ def tearDown(self): session.delete(self.get_user("NoRlsRoleUser")) session.commit() + @pytest.fixture() + def create_dataset(self): + with self.create_app().app_context(): + + dataset = SqlaTable(database_id=1, schema=None, table_name="table1") + db.session.add(dataset) + db.session.flush() + db.session.commit() + + yield dataset + + # rollback changes (assuming cascade delete) + db.session.delete(dataset) + db.session.commit() + + def _get_test_dataset(self): + return ( + db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1") + ).one_or_none() + + @pytest.mark.usefixtures("create_dataset") + def test_model_view_rls_add_success(self): + self.login(username="admin") + test_dataset = self._get_test_dataset() + rv = self.client.post( + "/rowlevelsecurityfiltersmodelview/add", + data=dict( + name="rls1", + description="Some description", + filter_type="Regular", + tables=[test_dataset.id], + roles=[security_manager.find_role("Alpha").id], + group_key="group_key_1", + clause="client_id=1", + ), + follow_redirects=True, + ) + self.assertEqual(rv.status_code, 200) + rls1 = ( + db.session.query(RowLevelSecurityFilter).filter_by(name="rls1") + ).one_or_none() + assert rls1 is not None + + # Revert data changes + db.session.delete(rls1) + db.session.commit() + + @pytest.mark.usefixtures("create_dataset") + def test_model_view_rls_add_name_unique(self): + self.login(username="admin") + test_dataset = self._get_test_dataset() + rv = self.client.post( + "/rowlevelsecurityfiltersmodelview/add", + data=dict( + name="rls_entry1", + description="Some description", + filter_type="Regular", + tables=[test_dataset.id], + roles=[security_manager.find_role("Alpha").id], + group_key="group_key_1", + clause="client_id=1", + ), + follow_redirects=True, + ) + self.assertEqual(rv.status_code, 200) + data = rv.data.decode("utf-8") + assert "Already exists." in data + + @pytest.mark.usefixtures("create_dataset") + def test_model_view_rls_add_tables_required(self): + self.login(username="admin") + rv = self.client.post( + "/rowlevelsecurityfiltersmodelview/add", + data=dict( + name="rls1", + description="Some description", + filter_type="Regular", + tables=[], + roles=[security_manager.find_role("Alpha").id], + group_key="group_key_1", + clause="client_id=1", + ), + follow_redirects=True, + ) + self.assertEqual(rv.status_code, 200) + data = rv.data.decode("utf-8") + assert "This field is required." in data + @pytest.mark.usefixtures("load_energy_table_with_slice") def test_rls_filter_alters_energy_query(self): g.user = self.get_user(username="alpha") diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 9950fb9fedda5..c5bfa4a16d600 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -186,6 +186,7 @@ def test_sql_lab_insert_rls( # now with RLS rls = RowLevelSecurityFilter( + name="sqllab_rls1", filter_type=RowLevelSecurityFilterType.REGULAR, tables=[SqlaTable(database_id=1, schema=None, table_name="t")], roles=[admin.roles[0]],