fix(sql): only return tables in current_database (#9748)
Co-authored-by: Phillip Cloud <[email protected]>
gforsyth and cpcloud authored Aug 5, 2024
1 parent 2e76af6 commit c7f5717
Showing 12 changed files with 185 additions and 89 deletions.
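
As a rough illustration of the intended behavior (not part of the commit, and shown here against the DuckDB backend with invented table and schema names): listing and schema lookups should resolve against the connection's current database unless another database is requested explicitly.

import ibis

con = ibis.duckdb.connect()  # in-memory; current database is "main"
con.create_table("t", schema=ibis.schema({"x": "int32"}))
con.raw_sql("CREATE SCHEMA s")
con.raw_sql("CREATE TABLE s.t (y TEXT)")  # same name, different database

con.list_tables()                        # ['t'] -- scoped to the current database
con.list_tables(database="s")            # ['t'] -- other databases only on request
con.table("t", database="s").schema()    # ibis.schema({"y": "string"})
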
3 changes: 1 addition & 2 deletions docker/mysql/startup.sql
@@ -1,5 +1,4 @@
 CREATE USER 'ibis'@'localhost' IDENTIFIED BY 'ibis';
 CREATE SCHEMA IF NOT EXISTS test_schema;
-GRANT CREATE, DROP ON *.* TO 'ibis'@'%';
-GRANT CREATE,SELECT,DROP ON `test_schema`.* TO 'ibis'@'%';
+GRANT CREATE,SELECT,DROP ON *.* TO 'ibis'@'%';
 FLUSH PRIVILEGES;
42 changes: 15 additions & 27 deletions ibis/backends/duckdb/__init__.py
@@ -337,41 +337,29 @@ def get_schema(
         -------
         sch.Schema
             Ibis schema
         """
-        conditions = [sg.column("table_name").eq(sge.convert(table_name))]
-
-        if catalog is not None:
-            conditions.append(sg.column("table_catalog").eq(sge.convert(catalog)))
-
-        if database is not None:
-            conditions.append(sg.column("table_schema").eq(sge.convert(database)))
-
-        query = (
-            sg.select(
-                "column_name",
-                "data_type",
-                sg.column("is_nullable").eq(sge.convert("YES")).as_("nullable"),
+        query = sge.Describe(
+            this=sg.table(
+                table_name, db=database, catalog=catalog, quoted=self.compiler.quoted
             )
-            .from_(sg.table("columns", db="information_schema"))
-            .where(sg.and_(*conditions))
-            .order_by("ordinal_position")
-        )
+        ).sql(self.dialect)
 
-        with self._safe_raw_sql(query) as cur:
-            meta = cur.fetch_arrow_table()
-
-        if not meta:
+        try:
+            result = self.con.sql(query)
+        except duckdb.CatalogException:
             raise exc.IbisError(f"Table not found: {table_name!r}")
+        else:
+            meta = result.fetch_arrow_table()
 
         names = meta["column_name"].to_pylist()
-        types = meta["data_type"].to_pylist()
-        nullables = meta["nullable"].to_pylist()
+        types = meta["column_type"].to_pylist()
+        nullables = meta["null"].to_pylist()
 
+        type_mapper = self.compiler.type_mapper
         return sch.Schema(
             {
-                name: self.compiler.type_mapper.from_string(typ, nullable=nullable)
-                for name, typ, nullable in zip(names, types, nullables)
+                name: type_mapper.from_string(typ, nullable=null == "YES")
+                for name, typ, null in zip(names, types, nullables)
             }
         )
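
For reference, a minimal standalone sketch of the DESCRIBE-based lookup the new get_schema builds on, using plain duckdb plus pyarrow rather than ibis (the table is invented; exact type strings depend on the DuckDB version):

import duckdb

con = duckdb.connect()  # in-memory: catalog "memory", database "main"
con.execute("CREATE TABLE t (x INTEGER, y TEXT)")

meta = con.sql('DESCRIBE "memory"."main"."t"').fetch_arrow_table()
names = meta["column_name"].to_pylist()   # ['x', 'y']
types = meta["column_type"].to_pylist()   # ['INTEGER', 'VARCHAR']
nulls = [null == "YES" for null in meta["null"].to_pylist()]

Because DESCRIBE resolves the fully qualified name through DuckDB's own catalog, a missing table surfaces as a CatalogException rather than as an empty information_schema result, which is what the new error handling catches.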

@@ -512,7 +500,7 @@ def _load_extensions(
         query = (
             sg.select(f.anon.unnest(f.list_append(C.aliases, C.extension_name)))
             .from_(f.duckdb_extensions())
-            .where(sg.and_(C.installed, C.loaded))
+            .where(C.installed, C.loaded)
         )
         with self._safe_raw_sql(query) as cur:
             installed = map(itemgetter(0), cur.fetchall())
30 changes: 30 additions & 0 deletions ibis/backends/duckdb/tests/test_client.py
@@ -345,3 +345,33 @@ def test_hugging_face(con, url, method_name):
     method = getattr(con, method_name)
     t = method(url)
     assert t.count().execute() > 0
+
+
+def test_multiple_tables_with_the_same_name(tmp_path):
+    # check within the same database
+    path = tmp_path / "test1.ddb"
+    with duckdb.connect(str(path)) as con:
+        con.execute("CREATE TABLE t (x INT)")
+        con.execute("CREATE SCHEMA s")
+        con.execute("CREATE TABLE s.t (y STRING)")
+
+    con = ibis.duckdb.connect(path)
+    t1 = con.table("t")
+    t2 = con.table("t", database="s")
+    assert t1.schema() == ibis.schema({"x": "int32"})
+    assert t2.schema() == ibis.schema({"y": "string"})
+
+    path = tmp_path / "test2.ddb"
+    with duckdb.connect(str(path)) as c:
+        c.execute("CREATE TABLE t (y DOUBLE[])")
+
+    # attach another catalog and check that too
+    con.attach(path, name="w")
+    t1 = con.table("t")
+    t2 = con.table("t", database="s")
+    assert t1.schema() == ibis.schema({"x": "int32"})
+    assert t2.schema() == ibis.schema({"y": "string"})
+
+    t3 = con.table("t", database="w.main")
+
+    assert t3.schema() == ibis.schema({"y": "array<float64>"})
21 changes: 11 additions & 10 deletions ibis/backends/exasol/__init__.py
@@ -345,16 +345,9 @@ def create_table(
 
         if temp:
             raise com.UnsupportedOperationError(
-                "Creating temp tables is not supported by Exasol."
+                f"Creating temp tables is not supported by {self.name}"
             )
 
-        if database is not None and database != self.current_database:
-            raise com.UnsupportedOperationError(
-                "Creating tables in other databases is not supported by Exasol"
-            )
-        else:
-            database = None
-
         quoted = self.compiler.quoted
 
         temp_memtable_view = None
@@ -435,7 +428,11 @@ def drop_database(
             raise NotImplementedError(
                 "`catalog` argument is not supported for the Exasol backend"
             )
-        drop_schema = sg.exp.Drop(kind="SCHEMA", this=name, exists=force)
+        drop_schema = sg.exp.Drop(
+            kind="SCHEMA",
+            this=sg.to_identifier(name, quoted=self.compiler.quoted),
+            exists=force,
+        )
         with self.begin() as con:
             con.execute(drop_schema.sql(dialect=self.dialect))
 
@@ -446,7 +443,11 @@ def create_database(
             raise NotImplementedError(
                 "`catalog` argument is not supported for the Exasol backend"
             )
-        create_database = sg.exp.Create(kind="SCHEMA", this=name, exists=force)
+        create_database = sg.exp.Create(
+            kind="SCHEMA",
+            this=sg.to_identifier(name, quoted=self.compiler.quoted),
+            exists=force,
+        )
         open_database = self.current_database
         with self.begin() as con:
             con.execute(create_database.sql(dialect=self.dialect))
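
A minimal sketch, independent of ibis, of what wrapping the schema name in sg.to_identifier(..., quoted=True) changes in the generated DDL (the schema name is invented and the exact rendering may vary by sqlglot version):

import sqlglot as sg

stmt = sg.exp.Create(
    kind="SCHEMA",
    this=sg.to_identifier("MY SCHEMA", quoted=True),  # invented name that needs quoting
    exists=True,
)
print(stmt.sql())  # roughly: CREATE SCHEMA IF NOT EXISTS "MY SCHEMA"
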
33 changes: 13 additions & 20 deletions ibis/backends/mssql/__init__.py
@@ -178,19 +178,15 @@ def get_schema(
         if name.startswith("ibis_cache_"):
             catalog, database = ("tempdb", "dbo")
             name = "##" + name
-        conditions = [sg.column("table_name").eq(sge.convert(name))]
-
-        if database is not None:
-            conditions.append(sg.column("table_schema").eq(sge.convert(database)))
 
         query = (
             sg.select(
-                "column_name",
-                "data_type",
-                "is_nullable",
-                "numeric_precision",
-                "numeric_scale",
-                "datetime_precision",
+                C.column_name,
+                C.data_type,
+                C.is_nullable,
+                C.numeric_precision,
+                C.numeric_scale,
+                C.datetime_precision,
             )
             .from_(
                 sg.table(
@@ -199,8 +195,11 @@
                     catalog=catalog or self.current_catalog,
                 )
             )
-            .where(*conditions)
-            .order_by("ordinal_position")
+            .where(
+                C.table_name.eq(sge.convert(name)),
+                C.table_schema.eq(sge.convert(database or self.current_database)),
+            )
+            .order_by(C.ordinal_position)
         )
 
         with self._safe_raw_sql(query) as cur:
@@ -487,26 +486,20 @@ def list_tables(
         """
         table_loc = self._warn_and_create_table_loc(database, schema)
         catalog, db = self._to_catalog_db_tuple(table_loc)
-        conditions = []
-
-        if db:
-            conditions.append(C.table_schema.eq(sge.convert(db)))
 
         sql = (
-            sg.select("table_name")
+            sg.select(C.table_name)
             .from_(
                 sg.table(
                     "TABLES",
                     db="INFORMATION_SCHEMA",
                     catalog=catalog if catalog is not None else self.current_catalog,
                 )
             )
+            .where(C.table_schema.eq(sge.convert(db or self.current_database)))
             .distinct()
         )
 
-        if conditions:
-            sql = sql.where(*conditions)
-
         sql = sql.sql(self.dialect)
 
         with self._safe_raw_sql(sql) as cur:
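
A rough sketch of the query shape list_tables now emits, built with plain sqlglot (ibis's C helper is essentially shorthand for sg.column); the catalog and schema names are placeholders:

import sqlglot as sg
import sqlglot.expressions as sge

sql = (
    sg.select(sg.column("table_name"))
    .from_(sg.table("TABLES", db="INFORMATION_SCHEMA", catalog="ibis_testing"))
    .where(sg.column("table_schema").eq(sge.convert("dbo")))
    .distinct()
)
print(sql.sql("tsql"))
# roughly: SELECT DISTINCT table_name FROM ibis_testing.INFORMATION_SCHEMA.TABLES WHERE table_schema = 'dbo'

The point of the change is the unconditional WHERE clause: when no database is passed, the filter falls back to current_database instead of being dropped entirely.
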
7 changes: 0 additions & 7 deletions ibis/backends/mysql/__init__.py
@@ -397,13 +397,6 @@ def create_table(
         if obj is None and schema is None:
             raise ValueError("Either `obj` or `schema` must be specified")
 
-        if database is not None and database != self.current_database:
-            raise com.UnsupportedOperationError(
-                "Creating tables in other databases is not supported by Postgres"
-            )
-        else:
-            database = None
-
         properties = []
 
         if temp:
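
With that guard removed, create_table on the MySQL backend can target another schema explicitly. A hypothetical illustration (connection parameters and names are invented, not taken from the diff):

import ibis

con = ibis.mysql.connect(
    host="localhost", user="ibis", password="ibis", database="ibis_testing"
)
con.create_table("t", schema=ibis.schema({"x": "int64"}), database="test_schema")
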
28 changes: 26 additions & 2 deletions ibis/backends/postgres/__init__.py
@@ -303,6 +303,20 @@ def _post_connect(self) -> None:
         with self.begin() as cur:
             cur.execute("SET TIMEZONE = UTC")
 
+    @property
+    def _session_temp_db(self) -> str | None:
+        # Postgres doesn't assign the temporary table database until the first
+        # temp table is created in a given session.
+        # Before that temp table is created, this will return `None`
+        # After a temp table is created, it will return `pg_temp_N` where N is
+        # some integer
+        res = self.raw_sql(
+            "select nspname from pg_namespace where oid = pg_my_temp_schema()"
+        ).fetchone()
+        if res is not None:
+            return res[0]
+        return res
+
     def list_tables(
         self,
         like: str | None = None,
@@ -458,7 +472,7 @@ def function(self, name: str, *, database: str | None = None) -> Callable:
                 on=n.oid.eq(p.pronamespace),
                 join_type="LEFT",
             )
-            .where(sg.and_(*predicates))
+            .where(*predicates)
         )
 
         def split_name_type(arg: str) -> tuple[str, dt.DataType]:
@@ -571,6 +585,16 @@ def get_schema(
 
         format_type = self.compiler.f["pg_catalog.format_type"]
 
+        # If no database is specified, assume the current database
+        db = database or self.current_database
+
+        dbs = [sge.convert(db)]
+
+        # If a database isn't specified, then include temp tables in the
+        # returned values
+        if database is None and (temp_table_db := self._session_temp_db) is not None:
+            dbs.append(sge.convert(temp_table_db))
+
         type_info = (
             sg.select(
                 a.attname.as_("column_name"),
@@ -591,7 +615,7 @@
             .where(
                 a.attnum > 0,
                 sg.not_(a.attisdropped),
-                n.nspname.eq(sge.convert(database)) if database is not None else TRUE,
+                n.nspname.isin(*dbs),
                 c.relname.eq(sge.convert(name)),
             )
             .order_by(a.attnum)
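
A short sketch of the Postgres behavior that _session_temp_db and the widened nspname filter rely on, using psycopg2 directly (connection parameters are invented): pg_my_temp_schema() points at no schema until the session creates its first temp table, after which it resolves to a pg_temp_N schema.

import psycopg2

con = psycopg2.connect("dbname=ibis_testing user=postgres")  # invented DSN
cur = con.cursor()

q = "select nspname from pg_namespace where oid = pg_my_temp_schema()"
cur.execute(q)
print(cur.fetchone())  # None -- no temp schema assigned to this session yet

cur.execute("CREATE TEMPORARY TABLE scratch (x int)")
cur.execute(q)
print(cur.fetchone())  # e.g. ('pg_temp_3',) -- now also searched by get_schema
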
6 changes: 6 additions & 0 deletions ibis/backends/risingwave/__init__.py
@@ -586,3 +586,9 @@ def drop_sink(
         )
         with self._safe_raw_sql(src):
             pass
+
+    @property
+    def _session_temp_db(self) -> str | None:
+        # Return `None`, because RisingWave does not implement temp tables like
+        # Postgres
+        return None
7 changes: 7 additions & 0 deletions ibis/backends/snowflake/__init__.py
@@ -646,6 +646,13 @@ def get_schema(
         catalog: str | None = None,
         database: str | None = None,
     ) -> Iterable[tuple[str, dt.DataType]]:
+        # this will always show temp tables with the same name as a non-temp
+        # table first
+        #
+        # snowflake puts temp tables in the same catalog and database as
+        # non-temp tables and differentiates between them using a different
+        # mechanism than other databases that often put temp tables in a hidden
+        # or intentionally-difficult-to-access catalog/database
         table = sg.table(
             table_name, db=database, catalog=catalog, quoted=self.compiler.quoted
         )
3 changes: 2 additions & 1 deletion ibis/backends/tests/errors.py
@@ -55,8 +55,9 @@
 
 try:
     from google.api_core.exceptions import BadRequest as GoogleBadRequest
+    from google.api_core.exceptions import NotFound as GoogleNotFound
 except ImportError:
-    GoogleBadRequest = None
+    GoogleBadRequest = GoogleNotFound = None
 
 try:
     from polars.exceptions import ColumnNotFoundError as PolarsColumnNotFoundError