Commit 9bf8654
add download file feature
Signed-off-by: lanzhiwang <[email protected]>
lanzhiwang committed May 6, 2023
1 parent d6f4816 commit 9bf8654
Showing 6 changed files with 196 additions and 146 deletions.
3 changes: 1 addition & 2 deletions jupyterlab_s3_browser/_version.py
@@ -17,8 +17,7 @@ def _fetchVersion():
             pass
 
     raise FileNotFoundError(  # noqa: F821
-        "Could not find package.json under dir {}".format(HERE)
-    )
+        "Could not find package.json under dir {}".format(HERE))
 
 
 __version__ = _fetchVersion()
112 changes: 81 additions & 31 deletions jupyterlab_s3_browser/handlers.py
@@ -4,19 +4,20 @@
 import base64
 import json
 import logging
+from pathlib import Path
 
+import boto3
+import s3fs
 import tornado
 from botocore.exceptions import NoCredentialsError
 from jupyter_server.base.handlers import APIHandler
 from jupyter_server.utils import url_path_join
-from pathlib import Path
-
-import s3fs
-import boto3
 
 
 class DirectoryNotEmptyException(Exception):
     """Raise for attempted deletions of non-empty directories"""
+
     pass
 

@@ -48,7 +49,7 @@ def create_s3_resource(config):
         )
 
     else:
-        return boto3.resource("s3")
+        return boto3.resource('s3')
 
 
 def _test_aws_s3_role_access():
@@ -57,10 +58,11 @@ def _test_aws_s3_role_access():
     """
     test = boto3.resource("s3")
     all_buckets = test.buckets.all()
-    result = [
-        {"name": bucket.name + "/", "path": bucket.name + "/", "type": "directory"}
-        for bucket in all_buckets
-    ]
+    result = [{
+        "name": bucket.name + "/",
+        "path": bucket.name + "/",
+        "type": "directory"
+    } for bucket in all_buckets]
     return result
 
@@ -79,8 +81,7 @@ def has_aws_s3_role_access():
                     access_key_id = line.split("=", 1)[1]
                     # aws keys reliably start with AKIA for long-term or ASIA for short-term
                     if not access_key_id.startswith(
-                        "AKIA"
-                    ) and not access_key_id.startswith("ASIA"):
+                            "AKIA") and not access_key_id.startswith("ASIA"):
                         # if any keys are not valid AWS keys, don't try to authenticate
                         logging.info(
                             "Found invalid AWS aws_access_key_id in ~/.aws/credentials file, "
@@ -111,12 +112,11 @@ def test_s3_credentials(endpoint_url, client_id, client_secret, session_token):
         aws_session_token=session_token,
     )
     all_buckets = test.buckets.all()
-    logging.debug(
-        [
-            {"name": bucket.name + "/", "path": bucket.name + "/", "type": "directory"}
-            for bucket in all_buckets
-        ]
-    )
+    logging.debug([{
+        "name": bucket.name + "/",
+        "path": bucket.name + "/",
+        "type": "directory"
+    } for bucket in all_buckets])
 
 
 class AuthHandler(APIHandler):  # pylint: disable=abstract-method
@@ -177,7 +177,8 @@ def post(self, path=""):
             client_secret = req["client_secret"]
             session_token = req["session_token"]
 
-            test_s3_credentials(endpoint_url, client_id, client_secret, session_token)
+            test_s3_credentials(endpoint_url, client_id, client_secret,
+                                session_token)
 
             self.config.endpoint_url = endpoint_url
             self.config.client_id = client_id
@@ -202,7 +203,51 @@ def convertS3FStoJupyterFormat(result):
     }
 
 
-class S3Handler(APIHandler):
+class FilesHandler(APIHandler):
+    """
+    Handles requests for getting files (e.g. for downloading)
+    """
+
+    @property
+    def config(self):
+        return self.settings["s3_config"]
+
+    @tornado.web.authenticated
+    def get(self, path=""):
+        """
+        Takes a path and returns the raw contents of the
+        file/object at that path, for downloading.
+        """
+        path = path.removeprefix("/")
+
+        try:
+            if not self.s3fs:
+                self.s3fs = create_s3fs(self.config)
+
+            self.s3fs.invalidate_cache()
+
+            with self.s3fs.open(path, "rb") as f:
+                result = f.read()
+
+        except S3ResourceNotFoundException as e:
+            result = json.dumps({
+                "error":
+                404,
+                "message":
+                "The requested resource could not be found.",
+            })
+        except Exception as e:
+            logging.error("Exception encountered during GET {}: {}".format(
+                path, e))
+            result = json.dumps({"error": 500, "message": str(e)})
+
+        self.finish(result)
+
+    s3fs = None
+    s3_resource = None
+
+
+class ContentsHandler(APIHandler):
     """
     Handles requests for getting S3 objects
     """
@@ -230,18 +275,18 @@ def get(self, path=""):
             self.s3fs.invalidate_cache()
 
             if (path and not path.endswith("/")) and (
-                "X-Custom-S3-Is-Dir" not in self.request.headers
+                    "X-Custom-S3-Is-Dir" not in self.request.headers
             ):  # TODO: replace with function
                 with self.s3fs.open(path, "rb") as f:
                     result = {
                         "path": path,
                         "type": "file",
-                        "content": base64.encodebytes(f.read()).decode("ascii"),
+                        "content":
+                        base64.encodebytes(f.read()).decode("ascii"),
                     }
             else:
                 raw_result = list(
-                    map(convertS3FStoJupyterFormat, self.s3fs.listdir(path))
-                )
+                    map(convertS3FStoJupyterFormat, self.s3fs.listdir(path)))
                 result = list(filter(lambda x: x["name"] != "", raw_result))
 
         except S3ResourceNotFoundException as e:
@@ -250,7 +295,8 @@ def get(self, path=""):
                 "message": "The requested resource could not be found.",
             }
         except Exception as e:
-            logging.error("Exception encountered during GET {}: {}".format(path, e))
+            logging.error("Exception encountered during GET {}: {}".format(
+                path, e))
             result = {"error": 500, "message": str(e)}
 
         self.finish(json.dumps(result))
@@ -283,7 +329,8 @@ def put(self, path=""):
                     result = {
                         "path": path,
                         "type": "file",
-                        "content": base64.encodebytes(f.read()).decode("ascii"),
+                        "content":
+                        base64.encodebytes(f.read()).decode("ascii"),
                     }
             elif "X-Custom-S3-Move-Src" in self.request.headers:
                 source = self.request.headers["X-Custom-S3-Move-Src"]
@@ -295,7 +342,8 @@ def put(self, path=""):
                     result = {
                         "path": path,
                         "type": "file",
-                        "content": base64.encodebytes(f.read()).decode("ascii"),
+                        "content":
+                        base64.encodebytes(f.read()).decode("ascii"),
                     }
             elif "X-Custom-S3-Is-Dir" in self.request.headers:
                 path = path.lower()
@@ -351,14 +399,12 @@ def delete(self, path=""):
             objects_matching_prefix = self.s3fs.listdir(path + "/")
             is_directory = (len(objects_matching_prefix) > 1) or (
                 (len(objects_matching_prefix) == 1)
-                and objects_matching_prefix[0]["Key"] != path
-            )
+                and objects_matching_prefix[0]['Key'] != path)
 
             if is_directory:
                 if (len(objects_matching_prefix) > 1) or (
                     (len(objects_matching_prefix) == 1)
-                    and objects_matching_prefix[0]["Key"] != path + "/"
-                ):
+                        and objects_matching_prefix[0]['Key'] != path + "/"):
                     raise DirectoryNotEmptyException()
                 else:
                     # for some reason s3fs.rm doesn't work reliably
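The emptiness check above leans on what an S3 prefix listing returns: a prefix is a non-empty directory when anything besides its own placeholder key comes back. A sketch of the same rule as a standalone predicate (the helper name is invented for illustration, not part of the commit):

    def prefix_holds_other_objects(listing, path):
        # listing = s3fs.listdir(path + "/"); True when objects other
        # than the directory placeholder key path + "/" exist under it.
        if len(listing) > 1:
            return True
        return len(listing) == 1 and listing[0]["Key"] != path + "/"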
@@ -393,7 +439,11 @@ def setup_handlers(web_app):
 
     base_url = web_app.settings["base_url"]
     handlers = [
-        (url_path_join(base_url, "jupyterlab_s3_browser", "auth(.*)"), AuthHandler),
-        (url_path_join(base_url, "jupyterlab_s3_browser", "files(.*)"), S3Handler),
+        (url_path_join(base_url, "jupyterlab_s3_browser",
+                       "auth(.*)"), AuthHandler),
+        (url_path_join(base_url, "jupyterlab_s3_browser",
+                       "contents(.*)"), ContentsHandler),
+        (url_path_join(base_url, "jupyterlab_s3_browser",
+                       "files(.*)"), FilesHandler),
     ]
     web_app.add_handlers(host_pattern, handlers)
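With the routes registered, the new download endpoint can be exercised directly. A minimal sketch using Python requests, assuming a local Jupyter server on localhost:8888 with token auth (the URL, token, bucket, and key below are illustrative, not from the commit):

    import requests

    BASE = "http://localhost:8888/jupyterlab_s3_browser"       # assumed server address
    HEADERS = {"Authorization": "token <your-jupyter-token>"}  # assumed token auth

    # Raw bytes from the new FilesHandler route; save straight to disk.
    resp = requests.get(BASE + "/files/my-bucket/data.csv", headers=HEADERS)
    resp.raise_for_status()
    with open("data.csv", "wb") as f:
        f.write(resp.content)

    # JSON listing from the ContentsHandler route.
    listing = requests.get(BASE + "/contents/my-bucket", headers=HEADERS).json()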
21 changes: 13 additions & 8 deletions jupyterlab_s3_browser/tests/test_get_s3.py
@@ -12,7 +12,11 @@ def test_get_single_bucket():
     s3.create_bucket(Bucket=bucket_name)
 
     result = jupyterlab_s3_browser.get_s3_objects_from_path(s3, "/")
-    assert result == [{"name": bucket_name, "type": "directory", "path": bucket_name}]
+    assert result == [{
+        "name": bucket_name,
+        "type": "directory",
+        "path": bucket_name
+    }]
 
 
 @mock_s3
@@ -24,10 +28,11 @@ def test_get_multiple_buckets():
         s3.create_bucket(Bucket=bucket_name)
 
     result = jupyterlab_s3_browser.get_s3_objects_from_path(s3, "/")
-    expected_result = [
-        {"name": bucket_name, "type": "directory", "path": bucket_name}
-        for bucket_name in bucket_names
-    ]
+    expected_result = [{
+        "name": bucket_name,
+        "type": "directory",
+        "path": bucket_name
+    } for bucket_name in bucket_names]
     assert result == expected_result
 
@@ -60,6 +65,6 @@ def test_get_files_inside_bucket():
         },
     ]
     print(result)
-    assert sorted(result, key=lambda i: i["name"]) == sorted(
-        expected_result, key=lambda i: i["name"]
-    )
+    assert sorted(result,
+                  key=lambda i: i["name"]) == sorted(expected_result,
+                                                     key=lambda i: i["name"])
14 changes: 8 additions & 6 deletions setup.py
@@ -37,7 +37,8 @@
 
 data_files_spec = [
     ("share/jupyter/labextensions/%s" % labext_name, str(lab_path), "**"),
-    ("share/jupyter/labextensions/%s" % labext_name, str(HERE), "install.json"),
+    ("share/jupyter/labextensions/%s" % labext_name, str(HERE),
+     "install.json"),
    (
         "etc/jupyter/jupyter_server_config.d",
         "jupyter-config/jupyter_server_config.d",
@@ -50,10 +51,9 @@
     ),
 ]
 
-
-cmdclass = create_cmdclass(
-    "jsdeps", package_data_spec=package_data_spec, data_files_spec=data_files_spec
-)
+cmdclass = create_cmdclass("jsdeps",
+                           package_data_spec=package_data_spec,
+                           data_files_spec=data_files_spec)
 
 js_command = combine_commands(
     install_npm(HERE, build_cmd="build:prod", npm=["jlpm"]),
@@ -99,7 +99,9 @@
         "singleton-decorator",
         "jupyterlab>=2.0.0",
     ],
-    extras_require={"dev": ["jupyter_packaging~=0.7.9", "pytest", "moto", "coverage"]},
+    extras_require={
+        "dev": ["jupyter_packaging~=0.7.9", "pytest", "moto", "coverage"]
+    },
 )
 
 if __name__ == "__main__":
(Diffs for the remaining 2 changed files are not shown.)