Skip to content

Commit

Permalink
feat: Allow extension-less file downloads. #1627 (#1635)
Browse files Browse the repository at this point in the history
  • Loading branch information
mturoci authored Dec 9, 2022
1 parent 2b16927 commit 39e1af1
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
26 changes: 13 additions & 13 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,6 @@
"bdist_wheel",
]
},
{
"name": "Debug PY App",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/py/venv/bin/wave",
"python": "${workspaceFolder}/py/venv/bin/python",
"cwd": "${workspaceFolder}/py/examples",
"args": [
"run",
"tour",
"--no-reload",
],
},
{
"name": "Debug Py Tests",
"type": "python",
Expand Down Expand Up @@ -104,6 +91,19 @@
"${workspaceFolder}/tools/wavegen/build/**/*.js"
]
},
{
"name": "Debug Wave App",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/py/venv/bin/wave",
"python": "${workspaceFolder}/py/venv/bin/python",
"cwd": "${workspaceFolder}/py/examples",
"args": [
"run",
"tour",
"--no-reload",
],
},
{
"name": "Debug Wave Server",
"type": "go",
Expand Down
7 changes: 5 additions & 2 deletions file_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ func (fs *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

if path.Ext(r.URL.Path) == "" { // ignore requests for directories and ext-less files
trimmedPrefix := strings.TrimPrefix(r.URL.Path, fs.baseURL)
fsDirPath := path.Join(fs.dir, trimmedPrefix)
// Ignore requests for directories and non-existent / unaccessible files.
if fileInfo, err := os.Stat(filepath.FromSlash(fsDirPath)); err != nil || fileInfo.IsDir() {
echo(Log{"t": "file_download", "path": r.URL.Path, "error": "not found"})
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}

echo(Log{"t": "file_download", "path": r.URL.Path})
r.URL.Path = strings.TrimPrefix(r.URL.Path, fs.baseURL) // public
r.URL.Path = trimmedPrefix // public
fs.handler.ServeHTTP(w, r)

case http.MethodPost:
Expand Down
19 changes: 12 additions & 7 deletions py/h2o_wave/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,9 +719,12 @@ def download(self, url: str, path: str) -> str:
# If path is a directory, get basename from url
filepath = os.path.join(path, os.path.basename(url)) if os.path.isdir(path) else path

with open(filepath, 'wb') as f:
with self._http.stream('GET', f'{_config.hub_host_address}{url}') as r:
for chunk in r.iter_bytes():
with self._http.stream('GET', f'{_config.hub_host_address}{url}') as res:
if res.status_code != 200:
res.read()
raise ServiceError(f'Download failed (code={res.status_code}): {res.text}')
with open(filepath, 'wb') as f:
for chunk in res.iter_bytes():
f.write(chunk)

return filepath
Expand Down Expand Up @@ -893,10 +896,12 @@ async def download(self, url: str, path: str) -> str:
path = os.path.abspath(path)
# If path is a directory, get basename from url
filepath = os.path.join(path, os.path.basename(url)) if os.path.isdir(path) else path

with open(filepath, 'wb') as f:
async with self._http.stream('GET', f'{_config.hub_host_address}{url}') as r:
async for chunk in r.aiter_bytes():
async with self._http.stream('GET', f'{_config.hub_host_address}{url}') as res:
if res.status_code != 200:
await res.aread()
raise ServiceError(f'Download failed (code={res.status_code}): {res.text}')
with open(filepath, 'wb') as f:
async for chunk in res.aiter_bytes():
f.write(chunk)

return filepath
Expand Down
4 changes: 3 additions & 1 deletion py/tests/test_python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

base_url = os.getenv('H2O_WAVE_BASE_URL', '/')


# TODO: Add cleanup (site.unload) to tests that upload files.
class TestPythonServer(unittest.TestCase):
def test_new_empty_card(self):
page = site['/test']
Expand Down Expand Up @@ -383,7 +385,7 @@ def test_multipart_server(self):

def test_upload_dir(self):
upload_path, = site.upload_dir(os.path.join('tests', 'test_folder'))
download_path = site.download(f'{base_url}{upload_path}test.txt', 'test.txt')
download_path = site.download(f'{upload_path}/test.txt', 'test.txt')
txt = read_file(download_path)
os.remove(download_path)
assert len(txt) > 0
Expand Down
5 changes: 2 additions & 3 deletions py/tests/test_python_server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .utils import read_file


# TODO: Add cleanup (site.unload) to tests that upload files.
class TestPythonServerAsync(unittest.IsolatedAsyncioTestCase):
def __init__(self, methodName: str = ...) -> None:
super().__init__(methodName)
Expand Down Expand Up @@ -68,8 +68,7 @@ async def test_multipart_server(self):

async def test_upload_dir(self):
upload_path, = await self.site.upload_dir(os.path.join('tests', 'test_folder'))
base_url = os.getenv('H2O_WAVE_BASE_URL', '/')
download_path = await self.site.download(f'{base_url}{upload_path}test.txt', 'test.txt')
download_path = await self.site.download(f'{upload_path}/test.txt', 'test.txt')
txt = read_file(download_path)
os.remove(download_path)
assert len(txt) > 0
Expand Down

0 comments on commit 39e1af1

Please sign in to comment.