Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-display user's query with an error message if an error occurs #1346

Merged
merged 13 commits into from
Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datasette/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def _connected_databases(self):
"is_memory": d.is_memory,
"hash": d.hash,
}
for name, d in sorted(self.databases.items(), key=lambda p: p[1].name)
for name, d in self.databases.items()
if name != "_internal"
]

Expand Down
4 changes: 4 additions & 0 deletions datasette/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def convert_specific_columns_to_json(rows, columns, json_cols):
def json_renderer(args, data, view_name):
"""Render a response as JSON"""
status_code = 200

# Handle the _json= parameter which may modify data["rows"]
json_cols = []
if "_json" in args:
Expand All @@ -44,6 +45,9 @@ def json_renderer(args, data, view_name):

# Deal with the _shape option
shape = args.get("_shape", "arrays")
# if there's an error, ignore the shape entirely
if data.get("error"):
shape = "arrays"

next_url = data.get("next_url")

Expand Down
5 changes: 4 additions & 1 deletion datasette/templates/query.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ <h1 style="padding-left: 10px; border-left: 10px solid #{{ database_color(databa
{% block description_source_license %}{% include "_description_source_license.html" %}{% endblock %}

<form class="sql" action="{{ urls.database(database) }}{% if canned_query %}/{{ canned_query }}{% endif %}" method="{% if canned_write %}post{% else %}get{% endif %}">
<h3>Custom SQL query{% if display_rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(display_rows|length) }} row{% if display_rows|length == 1 %}{% else %}s{% endif %}{% endif %} <span class="show-hide-sql">{% if hide_sql %}(<a href="{{ path_with_removed_args(request, {'_hide_sql': '1'}) }}">show</a>){% else %}(<a href="{{ path_with_added_args(request, {'_hide_sql': '1'}) }}">hide</a>){% endif %}</span></h3>
<h3>Custom SQL query{% if display_rows %} returning {% if truncated %}more than {% endif %}{{ "{:,}".format(display_rows|length) }} row{% if display_rows|length == 1 %}{% else %}s{% endif %}{% endif %}{% if not query_error %} <span class="show-hide-sql">{% if hide_sql %}(<a href="{{ path_with_removed_args(request, {'_hide_sql': '1'}) }}">show</a>){% else %}(<a href="{{ path_with_added_args(request, {'_hide_sql': '1'}) }}">hide</a>){% endif %}</span>{% endif %}</h3>
{% if query_error %}
<p class="message-error">{{ query_error }}</p>
{% endif %}
{% if not hide_sql %}
{% if editable and allow_execute_sql %}
<p><textarea id="sql-editor" name="sql">{% if query and query.sql %}{{ query.sql }}{% else %}select * from {{ tables[0].name|escape_sqlite }}{% endif %}</textarea></p>
Expand Down
11 changes: 10 additions & 1 deletion datasette/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import inspect
import itertools
import json
import markupsafe
import mergedeep
import os
import re
Expand Down Expand Up @@ -546,7 +547,7 @@ def detect_fts_sql(table):
)
)
""".format(
table=table
table=table.replace("'", "''")
)


Expand Down Expand Up @@ -777,6 +778,14 @@ async def write(self, bytes):
await self.writer.write(bytes)


class EscapeHtmlWriter:
def __init__(self, writer):
self.writer = writer

async def write(self, content):
await self.writer.write(markupsafe.escape(content))


_infinities = {float("inf"), float("-inf")}


Expand Down
71 changes: 62 additions & 9 deletions datasette/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datasette.database import QueryInterrupted
from datasette.utils import (
await_me_maybe,
EscapeHtmlWriter,
InvalidSql,
LimitedWriter,
call_with_supported_arguments,
Expand Down Expand Up @@ -262,6 +263,23 @@ async def get(self, request, db_name, **kwargs):

async def as_csv(self, request, database, hash, **kwargs):
stream = request.args.get("_stream")
# Do not calculate facets or counts:
extra_parameters = [
"{}=1".format(key)
for key in ("_nofacet", "_nocount")
if not request.args.get(key)
]
if extra_parameters:
if not request.query_string:
new_query_string = "&".join(extra_parameters)
else:
new_query_string = (
request.query_string + "&" + "&".join(extra_parameters)
)
new_scope = dict(
request.scope, query_string=new_query_string.encode("latin-1")
)
request.scope = new_scope
if stream:
# Some quick sanity checks
if not self.ds.setting("allow_csv_stream"):
Expand All @@ -276,6 +294,8 @@ async def as_csv(self, request, database, hash, **kwargs):
)
if isinstance(response_or_template_contexts, Response):
return response_or_template_contexts
elif len(response_or_template_contexts) == 4:
data, _, _, _ = response_or_template_contexts
else:
data, _, _ = response_or_template_contexts
except (sqlite3.OperationalError, InvalidSql) as e:
Expand All @@ -298,9 +318,27 @@ async def as_csv(self, request, database, hash, **kwargs):
if column in expanded_columns:
headings.append(f"{column}_label")

content_type = "text/plain; charset=utf-8"
preamble = ""
postamble = ""

trace = request.args.get("_trace")
if trace:
content_type = "text/html; charset=utf-8"
preamble = (
"<html><head><title>CSV debug</title></head>"
'<body><textarea style="width: 90%; height: 70vh">'
)
postamble = "</textarea></body></html>"

async def stream_fn(r):
nonlocal data
writer = csv.writer(LimitedWriter(r, self.ds.setting("max_csv_mb")))
nonlocal data, trace
limited_writer = LimitedWriter(r, self.ds.setting("max_csv_mb"))
if trace:
await limited_writer.write(preamble)
writer = csv.writer(EscapeHtmlWriter(limited_writer))
else:
writer = csv.writer(limited_writer)
first = True
next = None
while first or (next and stream):
Expand Down Expand Up @@ -333,7 +371,7 @@ async def stream_fn(r):
)
else:
# Otherwise generate URL for this query
cell = self.ds.absolute_url(
url = self.ds.absolute_url(
request,
path_with_format(
request=request,
Expand All @@ -347,6 +385,9 @@ async def stream_fn(r):
replace_format="csv",
),
)
cell = url.replace("&_nocount=1", "").replace(
"&_nofacet=1", ""
)
new_row.append(cell)
row = new_row
if not expanded_columns:
Expand All @@ -371,13 +412,14 @@ async def stream_fn(r):
sys.stderr.flush()
await r.write(str(e))
return
await limited_writer.write(postamble)

content_type = "text/plain; charset=utf-8"
headers = {}
if self.ds.cors:
headers["Access-Control-Allow-Origin"] = "*"
if request.args.get("_dl", None):
content_type = "text/csv; charset=utf-8"
if not trace:
content_type = "text/csv; charset=utf-8"
disposition = 'attachment; filename="{}.csv"'.format(
kwargs.get("table", database)
)
Expand Down Expand Up @@ -427,15 +469,22 @@ async def view_get(self, request, database, hash, correct_hash_provided, **kwarg

extra_template_data = {}
start = time.perf_counter()
status_code = 200
status_code = None
templates = []
try:
response_or_template_contexts = await self.data(
request, database, hash, **kwargs
)
if isinstance(response_or_template_contexts, Response):
return response_or_template_contexts

# If it has four items, it includes an HTTP status code
if len(response_or_template_contexts) == 4:
(
data,
extra_template_data,
templates,
status_code,
) = response_or_template_contexts
else:
data, extra_template_data, templates = response_or_template_contexts
except QueryInterrupted:
Expand Down Expand Up @@ -502,12 +551,15 @@ async def view_get(self, request, database, hash, correct_hash_provided, **kwarg
if isinstance(result, dict):
r = Response(
body=result.get("body"),
status=result.get("status_code", 200),
status=result.get("status_code", status_code or 200),
content_type=result.get("content_type", "text/plain"),
headers=result.get("headers"),
)
elif isinstance(result, Response):
r = result
if status_code is not None:
# Over-ride the status code
r.status = status_code
else:
assert False, f"{result} should be dict or Response"
else:
Expand Down Expand Up @@ -567,7 +619,8 @@ async def view_get(self, request, database, hash, correct_hash_provided, **kwarg
if "metadata" not in context:
context["metadata"] = self.ds.metadata
r = await self.render(templates, request=request, context=context)
r.status = status_code
if status_code is not None:
r.status = status_code

ttl = request.args.get("_ttl", None)
if ttl is None or not ttl.isdigit():
Expand Down
25 changes: 18 additions & 7 deletions datasette/views/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
path_with_added_args,
path_with_format,
path_with_removed_args,
sqlite3,
InvalidSql,
)
from datasette.utils.asgi import AsgiFileDownload, Response, Forbidden
Expand Down Expand Up @@ -239,6 +240,8 @@ async def data(

templates = [f"query-{to_css_class(database)}.html", "query.html"]

query_error = None

# Execute query - as write or as read
if write:
if request.method == "POST":
Expand Down Expand Up @@ -320,10 +323,15 @@ async def extra_template():
params_for_query = MagicParameters(params, request, self.ds)
else:
params_for_query = params
results = await self.ds.execute(
database, sql, params_for_query, truncate=True, **extra_args
)
columns = [r[0] for r in results.description]
try:
results = await self.ds.execute(
database, sql, params_for_query, truncate=True, **extra_args
)
columns = [r[0] for r in results.description]
except sqlite3.DatabaseError as e:
query_error = e
results = None
columns = []

if canned_query:
templates.insert(
Expand All @@ -337,7 +345,7 @@ async def extra_template():

async def extra_template():
display_rows = []
for row in results.rows:
for row in results.rows if results else []:
display_row = []
for column, value in zip(results.columns, row):
display_value = value
Expand Down Expand Up @@ -423,17 +431,20 @@ async def extra_template():

return (
{
"ok": not query_error,
"database": database,
"query_name": canned_query,
"rows": results.rows,
"truncated": results.truncated,
"rows": results.rows if results else [],
"truncated": results.truncated if results else False,
"columns": columns,
"query": {"sql": sql, "params": params},
"error": str(query_error) if query_error else None,
"private": private,
"allow_execute_sql": allow_execute_sql,
},
extra_template,
templates,
400 if query_error else 200,
)


Expand Down
Loading