Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add name, description and non null tables to RLS #20432

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):

__tablename__ = "row_level_security_filters"
id = Column(Integer, primary_key=True)
name = Column(String(255), unique=True, nullable=False)
description = Column(Text)
filter_type = Column(
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
)
Expand All @@ -2494,5 +2496,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
tables = relationship(
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
)

clause = Column(Text, nullable=False)
45 changes: 40 additions & 5 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flask_appbuilder.security.decorators import has_access
from flask_babel import lazy_gettext as _
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import Regexp
from wtforms.validators import DataRequired, Regexp

from superset import app, db
from superset.connectors.base.views import DatasourceModelView
Expand All @@ -47,6 +47,19 @@
logger = logging.getLogger(__name__)


class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-methods
"""
Select required flag on the input field will not work well on Chrome
Console error:
An invalid form control with name='tables' is not focusable.

This makes a simple override to the DataRequired to be used specifically with
select fields
"""

field_flags = ()


class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
datamodel = SQLAInterface(models.TableColumn)
# TODO TODO, review need for this on related_views
Expand Down Expand Up @@ -272,21 +285,39 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
edit_title = _("Edit Row level security filter")

list_columns = [
"name",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
"creator",
"modified",
]
order_columns = ["filter_type", "group_key", "clause", "modified"]
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
order_columns = ["name", "filter_type", "clause", "modified"]
edit_columns = [
"name",
"description",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
]
show_columns = edit_columns
search_columns = ("filter_type", "tables", "roles", "group_key", "clause")
search_columns = (
"name",
"description",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
)
add_columns = edit_columns
base_order = ("changed_on", "desc")
description_columns = {
"name": _("Choose a unique name"),
"description": _("Optionally add a detailed description"),
"filter_type": _(
"Regular filters add where clauses to queries if a user belongs to a "
"role referenced in the filter. Base filters apply filters to all queries "
Expand Down Expand Up @@ -319,12 +350,16 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
),
}
label_columns = {
"name": _("Name"),
"description": _("Description"),
"tables": _("Tables"),
"roles": _("Roles"),
"clause": _("Clause"),
"creator": _("Creator"),
"modified": _("Modified"),
}
validators_columns = {"tables": [SelectDataRequired()]}

if app.config["RLS_FORM_QUERY_REL_FIELDS"]:
add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"]
edit_form_query_rel_fields = add_form_query_rel_fields
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_unique_name_desc_rls

Revision ID: f3afaf1f11f0
Revises: e786798587de
Create Date: 2022-06-19 16:17:23.318618

"""

# revision identifiers, used by Alembic.
revision = "f3afaf1f11f0"
down_revision = "e786798587de"

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()


class RowLevelSecurityFilter(Base):
__tablename__ = "row_level_security_filters"
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(255), unique=True, nullable=False)


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
bind = op.get_bind()
session = Session(bind=bind)

op.add_column(
"row_level_security_filters", sa.Column("name", sa.String(length=255))
)
op.add_column(
"row_level_security_filters", sa.Column("description", sa.Text(), nullable=True)
)

# Set initial default names make sure we can have unique non null values
all_rls = session.query(RowLevelSecurityFilter).all()
for rls in all_rls:
rls.name = f"rls-{rls.id}"
session.commit()

# Now it's safe so set non-null and unique
# add unique constraint
with op.batch_alter_table("row_level_security_filters") as batch_op:
# batch mode is required for sqlite
batch_op.alter_column(
"name",
existing_type=sa.String(255),
nullable=False,
)
batch_op.create_unique_constraint("uq_rls_name", ["name"])
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("uq_rls_name", "row_level_security_filters", type_="unique")
op.drop_column("row_level_security_filters", "description")
op.drop_column("row_level_security_filters", "name")
# ### end Alembic commands ###
93 changes: 92 additions & 1 deletion tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from superset import db, security_manager
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import (
GuestTokenRlsRule,
GuestTokenResourceType,
GuestUser,
)
Expand Down Expand Up @@ -82,6 +81,7 @@ def setUp(self):

# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
self.rls_entry1 = RowLevelSecurityFilter()
self.rls_entry1.name = "rls_entry1"
self.rls_entry1.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
Expand All @@ -96,6 +96,7 @@ def setUp(self):

# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
self.rls_entry2 = RowLevelSecurityFilter()
self.rls_entry2.name = "rls_entry2"
self.rls_entry2.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
Expand All @@ -109,6 +110,7 @@ def setUp(self):

# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
self.rls_entry3 = RowLevelSecurityFilter()
self.rls_entry3.name = "rls_entry3"
self.rls_entry3.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
Expand All @@ -122,6 +124,7 @@ def setUp(self):

# Create Base RowLevelSecurityFilter (birth_names boys)
self.rls_entry4 = RowLevelSecurityFilter()
self.rls_entry4.name = "rls_entry4"
self.rls_entry4.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
Expand All @@ -146,6 +149,94 @@ def tearDown(self):
session.delete(self.get_user("NoRlsRoleUser"))
session.commit()

@pytest.fixture()
def create_dataset(self):
with self.create_app().app_context():

dataset = SqlaTable(database_id=1, schema=None, table_name="table1")
db.session.add(dataset)
db.session.flush()
db.session.commit()

yield dataset

# rollback changes (assuming cascade delete)
db.session.delete(dataset)
db.session.commit()

def _get_test_dataset(self):
return (
db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1")
).one_or_none()

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_success(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
rls1 = (
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
).one_or_none()
assert rls1 is not None

# Revert data changes
db.session.delete(rls1)
db.session.commit()

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_name_unique(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls_entry1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "Already exists." in data

@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
self.login(username="admin")
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "This field is required." in data

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha")
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/sql_lab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_sql_lab_insert_rls(

# now with RLS
rls = RowLevelSecurityFilter(
name="sqllab_rls1",
filter_type=RowLevelSecurityFilterType.REGULAR,
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
roles=[admin.roles[0]],
Expand Down