Skip to content

Commit

Permalink
Add roles
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Feb 14, 2024
1 parent 4a75f8e commit 472d59a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 65 deletions.
147 changes: 82 additions & 65 deletions src/spyglass/utils/database_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@

from spyglass.utils.logging import logger

GRANT_ALL = "GRANT ALL PRIVILEGES ON "
GRANT_SEL = "GRANT SELECT ON "
CREATE_USR = "CREATE USER IF NOT EXISTS "
TEMP_PASS = " IDENTIFIED BY 'temppass';"
ESC = r"\_%"
SHARED_MODULES = [
"common",
"spikesorting",
Expand All @@ -25,6 +20,12 @@
"waveform",
"mua",
]
GRANT_ALL = "GRANT ALL PRIVILEGES ON "
GRANT_SEL = "GRANT SELECT ON "
CREATE_USR = "CREATE USER IF NOT EXISTS "
CREATE_ROLE = "CREATE ROLE IF NOT EXISTS "
TEMP_PASS = " IDENTIFIED BY 'temppass';"
ESC = r"\_%"


class DatabaseSettings:
Expand All @@ -38,6 +39,12 @@ def __init__(
):
"""Class to manage common database settings
Roles:
- dj_guest: select for all prefix
- dj_collab: select for all prefix, all for user prefix
- dj_user: select for all prefix, all for user prefix, all for shared
- dj_admin: all for all prefix
Parameters
----------
user_name : str, optional
Expand All @@ -61,49 +68,78 @@ def __init__(
self.target_database = target_database or "mysql"

@property
def _add_collab_usr_sql(self):
return [
# Create the user (if not already created) and set the password
f"{CREATE_USR}'{self.user}'@'%'{TEMP_PASS}\n",
# Grant privileges to databases matching the user_name pattern
f"{GRANT_ALL}`{self.user}{ESC}`.* TO '{self.user}'@'%';\n",
# Grant SELECT privileges on all databases
f"{GRANT_SEL}`%`.* TO '{self.user}'@'%';\n",
def _create_roles_sql(self):
guest_role = [
f"{CREATE_ROLE}'dj_guest';\n",
f"{GRANT_SEL}`%`.* TO 'dj_guest';\n",
]
collab_role = [
f"{CREATE_ROLE}'dj_collab';\n",
f"{GRANT_SEL}`%`.* TO 'dj_collab';\n",
] # also gets own prefix below
user_role = [
f"{CREATE_ROLE}'dj_user';\n",
f"{GRANT_SEL}`%`.* TO 'dj_user';\n",
] + [
f"{GRANT_ALL}`{module}`.* TO 'dj_user';\n"
for module in self.shared_modules
] # also gets own prefix below
admin_role = [
f"{CREATE_ROLE}'dj_admin';\n",
f"{GRANT_ALL}`%`.* TO 'dj_admin';\n",
]

def add_collab_user(self):
"""Add collaborator user with full permissions to shared modules"""
file = self.write_temp_file(self._add_collab_usr_sql)
self.exec(file)
return guest_role + collab_role + user_role + admin_role

def _create_user_sql(self, role):
"""Create user and grant role"""
return [
f"{CREATE_USR}'{self.user}'@'%'{TEMP_PASS}\n", # create user
f"GRANT {role} TO '{self.user}'@'%';\n", # grant role
]

@property
def _add_dj_guest_sql(self):
# Note: changing to temppass for uniformity
def _user_prefix_sql(self):
"""Grant user all permissions for user prefix"""
return [
# Create the user (if not already created) and set the password
f"{CREATE_USR}'{self.user}'@'%'{TEMP_PASS}\n",
# Grant privileges
f"{GRANT_SEL}`%`.* TO '{self.user}'@'%';\n",
f"{GRANT_ALL}`{self.user}{ESC}`.* TO '{self.user}'@'%';\n",
]

def add_dj_guest(self, method="file"):
@property
def _add_guest_sql(self):
return self._create_user_sql("dj_guest")

@property
def _add_collab_sql(self):
return self._create_user_sql("dj_collab") + self._user_prefix_sql

@property
def _add_user_sql(self):
return self._create_user_sql("dj_user") + self._user_prefix_sql

def _add_module_sql(self, module_name):
return [f"{GRANT_ALL}`{module_name}{ESC}`.* TO dj_user;\n"]

def add_collab(self):
"""Add collaborator user with full permissions to shared modules"""
file = self.write_temp_file(self._add_collab_sql)
self.exec(file)

def add_guest(self):
"""Add guest user with select permissions to shared modules"""
file = self.write_temp_file(self._add_dj_guest_sql)
file = self.write_temp_file(self._add_guest_sql)
self.exec(file)

def _find_group(self):
# find the kachery-users group
groups = grp.getgrall()
groups = grp.getgrall() # find the kachery-users group
group_found = False # initialize the flag as False
for group in groups:
if group.gr_name == self.target_group:
group_found = (
True # set the flag to True when the group is found
)
# set the flag to True when the group is found
group_found = True
break

# Check if the group was found
if not group_found:
if not group_found: # Check if the group was found
if self.debug:
logger.info(f"All groups: {[g.gr_name for g in groups]}")
sys.exit(
Expand All @@ -112,35 +148,12 @@ def _find_group(self):

return group

def _add_module_sql(self, module_name, group):
return [
f"{GRANT_ALL}`{module_name}{ESC}`.* TO `{user}`@'%';\n"
# get a list of usernames
for user in group.gr_mem
]

def add_module(self, module_name):
"""Add module to database. Grant permissions to all users in group"""
logger.info(f"Granting everyone permissions to module {module_name}")
group = self._find_group()
file = self.write_temp_file(self._add_module_sql(module_name, group))
file = self.write_temp_file(self._add_module_sql(module_name))
self.exec(file)

@property
def _add_dj_user_sql(self):
return (
[
f"{CREATE_USR}'{self.user}'@'%' "
+ "IDENTIFIED BY 'temppass';\n",
f"{GRANT_ALL}`{self.user}{ESC}`.* TO '{self.user}'@'%';" + "\n",
]
+ [
f"{GRANT_ALL}`{module}`.* TO '{self.user}'@'%';\n"
for module in self.shared_modules
]
+ [f"{GRANT_SEL}`%`.* TO '{self.user}'@'%';\n"]
)

def add_dj_user(self, check_exists=True):
"""Add user to database with permissions to shared modules"""
if check_exists:
Expand All @@ -149,10 +162,15 @@ def add_dj_user(self, check_exists=True):
logger.info("Creating database user ", self.user)
else:
sys.exit(
f"Error: could not find {self.user} in home dir: {user_home}"
f"Error: couldn't find {self.user} in home dir: {user_home}"
)

file = self.write_temp_file(self._add_dj_user_sql)
file = self.write_temp_file(self._add_user_sql)
self.exec(file)

def add_roles(self):
"""Add roles to database"""
file = self.write_temp_file(self._create_roles_sql)
self.exec(file)

def write_temp_file(self, content: list) -> tempfile.NamedTemporaryFile:
Expand All @@ -176,11 +194,10 @@ def exec(self, file):
if self.debug:
return

if self.target_database == "mysql":
cmd = f"mysql -p -h {self.host} < {file.name}"
else:
cmd = (
f"docker exec -i {self.target_database} mysql -u {self.user} "
+ f"--password=tutorial < {file.name}"
)
cmd = (
f"mysql -p -h {self.host} < {file.name}"
if self.target_database == "mysql"
else f"docker exec -i {self.target_database} mysql -u {self.user} "
+ f"--password=tutorial < {file.name}"
)
os.system(cmd)
9 changes: 9 additions & 0 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from datajoint.utils import get_master, user_choice
from pymysql.err import DataError

from spyglass.utils.database_settings import SHARED_MODULES
from spyglass.utils.dj_chains import TableChain, TableChains
from spyglass.utils.dj_helper_fn import fetch_nwb
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
Expand Down Expand Up @@ -55,6 +56,14 @@ class SpyglassMixin:
_session_pk = None # Session primary key. Mixin is ambivalent to Session pk
_member_pk = None # LabMember primary key. Mixin ambivalent table structure

def __init__(self, *args, **kwargs):
"""Initialize SpyglassMixin.
Checks that schema prefix is in SHARED_MODULES.
"""
if self.database and self.database.split("_")[0] not in SHARED_MODULES:
raise ValueError(f"Database {self.database} not in SHARED_MODULES")

# ------------------------------- fetch_nwb -------------------------------

@cached_property
Expand Down

0 comments on commit 472d59a

Please sign in to comment.