diff --git a/Makefile b/Makefile
index 35051be9c1..53f3de3862 100644
--- a/Makefile
+++ b/Makefile
@@ -19,7 +19,7 @@ install-poetry:
 	pip install poetry==1.8.2
 
 install-dependencies:
-	poetry install -E pyarrow -E hive -E s3fs -E glue -E adlfs -E duckdb -E ray -E sql-postgres -E gcsfs -E sql-sqlite -E daft
+	poetry install -E pyarrow -E hive -E s3fs -E glue -E adlfs -E duckdb -E ray -E sql-postgres -E gcsfs -E sql-sqlite -E daft -E snowflake
 
 install: | install-poetry install-dependencies
 
diff --git a/poetry.lock b/poetry.lock
index 76cf30e045..d69bd6da33 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -193,6 +193,17 @@ files = [
     {file = "antlr4_python3_runtime-4.13.1-py3-none-any.whl", hash = "sha256:78ec57aad12c97ac039ca27403ad61cb98aaec8a3f9bb8144f889aa0fa28b943"},
 ]
 
+[[package]]
+name = "asn1crypto"
+version = "1.5.1"
+description = "Fast ASN.1 parser and serializer with definitions for private keys, public keys, certificates, CRL, OCSP, CMS, PKCS#3, PKCS#7, PKCS#8, PKCS#12, PKCS#5, X.509 and TSP"
+optional = true
+python-versions = "*"
+files = [
+    {file = "asn1crypto-1.5.1-py2.py3-none-any.whl", hash = "sha256:db4e40728b728508912cbb3d44f19ce188f218e9eba635821bb4b68564f8fd67"},
+    {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"},
+]
+
 [[package]]
 name = "async-timeout"
 version = "4.0.3"
@@ -3169,6 +3180,24 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte
 docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
 tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
 
+[[package]]
+name = "pyopenssl"
+version = "24.1.0"
+description = "Python wrapper module around the OpenSSL library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
+    {file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
+]
+
+[package.dependencies]
+cryptography = ">=41.0.5,<43"
+
+[package.extras]
+docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
+test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
+
 [[package]]
 name = "pyparsing"
 version = "3.1.2"
@@ -3893,6 +3922,65 @@ files = [
     {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
 
+[[package]]
+name = "snowflake-connector-python"
+version = "3.10.0"
+description = "Snowflake Connector for Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "snowflake_connector_python-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e2afca4bca70016519d1a7317c498f1d9c56140bf3e40ea40bddcc95fe827ca"},
+    {file = "snowflake_connector_python-3.10.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:d19bde29f89b226eb22af4c83134ecb5c229da1d5e960a01b8f495df78dcdc36"},
+    {file = "snowflake_connector_python-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bfe013ed97b4dd2e191fd6770a14030d29dd0108817d6ce76b9773250dd2d560"},
+    {file = "snowflake_connector_python-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0917c9f9382d830907e1a18ee1208537b203618700a9c671c2a20167b30f574"},
+    {file = "snowflake_connector_python-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:7e828bc99240433e6552ac4cc4e37f223ae5c51c7880458ddb281668503c7491"},
+    {file = "snowflake_connector_python-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0d3d06d758455c50b998eabc1fd972a1f67faa5c85ef250fd5986f5a41aab0b"},
+    {file = "snowflake_connector_python-3.10.0-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:4602cb19b204bb03e03d65c6d5328467c9efc0fec53ca56768c3747c8dc8a70f"},
+    {file = "snowflake_connector_python-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb1a04b496bbd3e1e2e926df82b2369887b2eea958f535fb934c240bfbabf6c5"},
+    {file = "snowflake_connector_python-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c889f9f60f915d657e0a0ad2e6cc52cdcafd9bcbfa95a095aadfd8bcae62b819"},
+    {file = "snowflake_connector_python-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:8e441484216ed416a6ed338133e23bd991ac4ba2e46531f4d330f61568c49314"},
+    {file = "snowflake_connector_python-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bb4aced19053c67513cecc92311fa9d3b507b2277698c8e987d404f6f3a49fb2"},
+    {file = "snowflake_connector_python-3.10.0-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:858315a2feff86213b079c6293ad8d850a778044c664686802ead8bb1337e1bc"},
+    {file = "snowflake_connector_python-3.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:adf16e1ca9f46d3bdf68e955ffa42075ebdb251e3b13b59003d04e4fea7d579a"},
+    {file = "snowflake_connector_python-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4c5c2a08b39086a5348502652ad4fdf24871d7ab30fd59f6b7b57249158468c"},
+    {file = "snowflake_connector_python-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:05011286f42c52eb3e5a6db59ee3eaf79f3039f3a19d7ffac6f4ee143779c637"},
+    {file = "snowflake_connector_python-3.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:569301289ada5b0d72d0bd8432b7ca180220335faa6d9a0f7185f60891db6f2c"},
+    {file = "snowflake_connector_python-3.10.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:4e5641c70a12da9804b74f350b8cbbdffdc7aca5069b096755abd2a1fdcf5d1b"},
+    {file = "snowflake_connector_python-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ff767a1b8c48431549ac28884f8bd9647e63a23f470b05f6ab8d143c4b1475"},
+    {file = "snowflake_connector_python-3.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e52bbc1e2e7bda956525b4229d7f87579f8cabd7d5506b12aa754c4bcdc8c8d7"},
+    {file = "snowflake_connector_python-3.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:280a8dcca0249e864419564e38764c08f8841900d9872fec2f2855fda494b29f"},
+    {file = "snowflake_connector_python-3.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:67bf570230b0cf818e6766c17245c7355a1f5ea27778e54ab8d09e5bb3536ad9"},
+    {file = "snowflake_connector_python-3.10.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:aa1e26f9c571d2c4206da5c978c1b345ffd798d3db1f9ae91985e6243c6bf94b"},
+    {file = "snowflake_connector_python-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73e9baa531d5156a03bfe5af462cf6193ec2a01cbb575edf7a2dd3b2a35254c7"},
+    {file = "snowflake_connector_python-3.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e03361c4749e4d65bf0d223fdea1c2d7a33af53b74e873929a6085d150aff17e"},
+    {file = "snowflake_connector_python-3.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:e8cddd4357e70ab55d7aeeed144cbbeb1ff658b563d7d8d307afc06178a367ec"},
+    {file = "snowflake_connector_python-3.10.0.tar.gz", hash = "sha256:7c7438e958753bd1174b73581d77c92b0b47a86c38d8ea0ba1ea23c442eb8e75"},
+]
+
+[package.dependencies]
+asn1crypto = ">0.24.0,<2.0.0"
+certifi = ">=2017.4.17"
+cffi = ">=1.9,<2.0.0"
+charset-normalizer = ">=2,<4"
+cryptography = ">=3.1.0,<43.0.0"
+filelock = ">=3.5,<4"
+idna = ">=2.5,<4"
+packaging = "*"
+platformdirs = ">=2.6.0,<5.0.0"
+pyjwt = "<3.0.0"
+pyOpenSSL = ">=16.2.0,<25.0.0"
+pytz = "*"
+requests = "<3.0.0"
+sortedcontainers = ">=2.4.0"
+tomlkit = "*"
+typing-extensions = ">=4.3,<5"
+urllib3 = {version = ">=1.21.1,<2.0.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+development = ["Cython", "coverage", "more-itertools", "numpy (<1.27.0)", "pendulum (!=2.1.1)", "pexpect", "pytest (<7.5.0)", "pytest-cov", "pytest-rerunfailures", "pytest-timeout", "pytest-xdist", "pytzdata"]
+pandas = ["pandas (>=1.0.0,<3.0.0)", "pyarrow"]
+secure-local-storage = ["keyring (>=23.1.0,<25.0.0)"]
+
 [[package]]
 name = "sortedcontainers"
 version = "2.4.0"
@@ -4062,6 +4150,17 @@ files = [
     {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
 ]
 
+[[package]]
+name = "tomlkit"
+version = "0.12.4"
+description = "Style preserving TOML library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "tomlkit-0.12.4-py3-none-any.whl", hash = "sha256:5cd82d48a3dd89dee1f9d64420aa20ae65cfbd00668d6f094d7578a78efbb77b"},
+    {file = "tomlkit-0.12.4.tar.gz", hash = "sha256:7ca1cfc12232806517a8515047ba66a19369e71edf2439d0f5824f91032b6cc3"},
+]
+
 [[package]]
 name = "tqdm"
 version = "4.66.2"
@@ -4120,23 +4219,6 @@ brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotl
 secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
 socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
 
-[[package]]
-name = "urllib3"
-version = "2.0.7"
-description = "HTTP library with thread-safe connection pooling, file post, and more."
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
-    {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
-]
-
-[package.extras]
-brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
-socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
-zstd = ["zstandard (>=0.18.0)"]
-
 [[package]]
 name = "virtualenv"
 version = "20.25.0"
@@ -4456,6 +4538,7 @@ pyarrow = ["pyarrow"]
 ray = ["pandas", "pyarrow", "ray"]
 s3fs = ["s3fs"]
 snappy = ["python-snappy"]
+snowflake = ["snowflake-connector-python"]
 sql-postgres = ["psycopg2-binary", "sqlalchemy"]
 sql-sqlite = ["sqlalchemy"]
 zstandard = ["zstandard"]
@@ -4463,4 +4546,4 @@ zstandard = ["zstandard"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8"
-content-hash = "91de7f775ff1499d79db490197eee5aadc7078b5244d86e56d8626c2615645f6"
+content-hash = "2cf8462414b7ee3c97034cc41b95abc43996f6c7702d1ac3b74bd85d014e42b5"
diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index 18d803fe1c..83fa8faecd 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -100,6 +100,7 @@ class CatalogType(Enum):
     GLUE = "glue"
     DYNAMODB = "dynamodb"
     SQL = "sql"
+    SNOWFLAKE = "snowflake"
 
 
 def load_rest(name: str, conf: Properties) -> Catalog:
@@ -146,12 +147,22 @@ def load_sql(name: str, conf: Properties) -> Catalog:
     ) from exc
 
 
+def load_snowflake(name: str, conf: Properties) -> Catalog:
+    try:
+        from pyiceberg.catalog.snowflake_catalog import SnowflakeCatalog
+
+        return SnowflakeCatalog(name, **conf)
+    except ImportError as exc:
+        raise NotInstalledError("Snowflake support not installed: pip install 'pyiceberg[snowflake]'") from exc
+
+
 AVAILABLE_CATALOGS: dict[CatalogType, Callable[[str, Properties], Catalog]] = {
     CatalogType.REST: load_rest,
     CatalogType.HIVE: load_hive,
     CatalogType.GLUE: load_glue,
     CatalogType.DYNAMODB: load_dynamodb,
     CatalogType.SQL: load_sql,
+    CatalogType.SNOWFLAKE: load_snowflake,
 }
 
 
diff --git a/pyiceberg/catalog/snowflake_catalog.py b/pyiceberg/catalog/snowflake_catalog.py
new file mode 100644
index 0000000000..508da2eaa2
--- /dev/null
+++ b/pyiceberg/catalog/snowflake_catalog.py
@@ -0,0 +1,221 @@
+from __future__ import annotations
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Iterator, List, Optional, Set, Union
+
+import pyarrow as pa
+from boto3.session import Session
+from snowflake.connector import DictCursor, SnowflakeConnection
+
+from pyiceberg.catalog import MetastoreCatalog, PropertiesUpdateSummary
+from pyiceberg.exceptions import NoSuchTableError, TableAlreadyExistsError
+from pyiceberg.io import S3_ACCESS_KEY_ID, S3_REGION, S3_SECRET_ACCESS_KEY, S3_SESSION_TOKEN
+from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, StaticTable, Table, sorting
+from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
+
+
+class SnowflakeCatalog(MetastoreCatalog):
+    @dataclass(frozen=True, eq=True)
+    class _SnowflakeIdentifier:
+        database: str | None
+        schema: str | None
+        table: str | None
+
+        def __iter__(self) -> Iterator[str]:
+            """
+            Iterate over the non-None parts of the identifier.
+
+            Returns:
+                Iterator[str]: Iterator of the non-None parts of the identifier.
+            """
+            yield from filter(None, [self.database, self.schema, self.table])
+
+        @classmethod
+        def table_from_string(cls, identifier: str) -> SnowflakeCatalog._SnowflakeIdentifier:
+            parts = identifier.split(".")
+            if len(parts) == 1:
+                return cls(None, None, parts[0])
+            elif len(parts) == 2:
+                return cls(None, parts[0], parts[1])
+            elif len(parts) == 3:
+                return cls(parts[0], parts[1], parts[2])
+
+            raise ValueError(f"Invalid identifier: {identifier}")
+
+        @classmethod
+        def schema_from_string(cls, identifier: str) -> SnowflakeCatalog._SnowflakeIdentifier:
+            parts = identifier.split(".")
+            if len(parts) == 1:
+                return cls(None, parts[0], None)
+            elif len(parts) == 2:
+                return cls(parts[0], parts[1], None)
+
+            raise ValueError(f"Invalid identifier: {identifier}")
+
+        @property
+        def table_name(self) -> str:
+            return ".".join(self)
+
+        @property
+        def schema_name(self) -> str:
+            return ".".join(self)
+
+    def __init__(self, name: str, **properties: str):
+        super().__init__(name, **properties)
+
+        params = {
+            "user": properties["user"],
+            "account": properties["account"],
+        }
+
+        if "authenticator" in properties:
+            params["authenticator"] = properties["authenticator"]
+
+        if "password" in properties:
+            params["password"] = properties["password"]
+
+        if "private_key" in properties:
+            params["private_key"] = properties["private_key"]
+
+        self.connection = SnowflakeConnection(**params)
+
+    def load_table(self, identifier: Union[str, Identifier]) -> Table:
+        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
+            identifier if isinstance(identifier, str) else ".".join(identifier)
+        )
+
+        metadata_query = "SELECT SYSTEM$GET_ICEBERG_TABLE_INFORMATION(%s) AS METADATA"
+
+        with self.connection.cursor(DictCursor) as cursor:
+            try:
+                cursor.execute(metadata_query, (sf_identifier.table_name,))
+                metadata = json.loads(cursor.fetchone()["METADATA"])["metadataLocation"]
+            except Exception as e:
+                raise NoSuchTableError(f"Table {sf_identifier.table_name} not found") from e
+
+        session = Session()
+        credentials = session.get_credentials()
+        current_credentials = credentials.get_frozen_credentials()
+
+        s3_props = {
+            S3_ACCESS_KEY_ID: current_credentials.access_key,
+            S3_SECRET_ACCESS_KEY: current_credentials.secret_key,
+            S3_SESSION_TOKEN: current_credentials.token,
+            S3_REGION: os.environ.get("AWS_REGION", "us-east-1"),
+        }
+
+        tbl = StaticTable.from_metadata(metadata, properties=s3_props)
+        tbl.identifier = tuple(identifier.split(".")) if isinstance(identifier, str) else identifier
+        tbl.catalog = self
+
+        return tbl
+
+    def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
+        query = "CREATE ICEBERG TABLE (%s) METADATA_FILE_PATH = (%s)"
+        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
+            identifier if isinstance(identifier, str) else ".".join(identifier)
+        )
+
+        with self.connection.cursor(DictCursor) as cursor:
+            try:
+                cursor.execute(query, (sf_identifier.table_name, metadata_location))
+            except Exception as e:
+                raise TableAlreadyExistsError(f"Table {sf_identifier.table_name} already exists") from e
+
+        return self.load_table(identifier)
+
+    def drop_table(self, identifier: Union[str, Identifier]) -> None:
+        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
+            identifier if isinstance(identifier, str) else ".".join(identifier)
+        )
+
+        query = "DROP TABLE IF EXISTS (%s)"
+
+        with self.connection.cursor(DictCursor) as cursor:
+            cursor.execute(query, (sf_identifier.table_name,))
+
+    def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
+        sf_from_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
+            from_identifier if isinstance(from_identifier, str) else ".".join(from_identifier)
+        )
+        sf_to_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
+            to_identifier if isinstance(to_identifier, str) else ".".join(to_identifier)
+        )
+
+        query = "ALTER TABLE (%s) RENAME TO (%s)"
+
+        with self.connection.cursor(DictCursor) as cursor:
+            cursor.execute(query, (sf_from_identifier.table_name, sf_to_identifier.table_name))
+
+        return self.load_table(to_identifier)
+
+    def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
+        raise NotImplementedError
+
+    def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
+        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
+            namespace if isinstance(namespace, str) else ".".join(namespace)
+        )
+
+        db_query = "CREATE DATABASE IF NOT EXISTS (%s)"
+        schema_query = "CREATE SCHEMA IF NOT EXISTS (%s)"
+
+        with self.connection.cursor(DictCursor) as cursor:
+            if sf_identifier.database:
+                cursor.execute(db_query, (sf_identifier.database,))
+            cursor.execute(schema_query, (sf_identifier.schema_name,))
+
+    def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
+        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
+            namespace if isinstance(namespace, str) else ".".join(namespace)
+        )
+
+        sf_query = "DROP SCHEMA IF EXISTS (%s)"
+        db_query = "DROP DATABASE IF EXISTS (%s)"
+
+        with self.connection.cursor(DictCursor) as cursor:
+            if sf_identifier.database:
+                cursor.execute(db_query, (sf_identifier.database,))
+            cursor.execute(sf_query, (sf_identifier.schema_name,))
+
+    def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
+        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
+            namespace if isinstance(namespace, str) else ".".join(namespace)
+        )
+
+        schema_query = "SHOW ICEBERG TABLES IN SCHEMA (%s)"
+        db_query = "SHOW ICEBERG TABLES IN DATABASE (%s)"
+
+        with self.connection.cursor(DictCursor) as cursor:
+            if sf_identifier.database:
+                cursor.execute(db_query, (sf_identifier.database,))
+            else:
+                cursor.execute(schema_query, (sf_identifier.schema,))
+
+            return [(row["database_name"], row["schema_name"], row["table_name"]) for row in cursor.fetchall()]
+
+    def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
+        raise NotImplementedError
+
+    def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
+        raise NotImplementedError
+
+    def update_namespace_properties(
+        self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
+    ) -> PropertiesUpdateSummary:
+        raise NotImplementedError
+
+    def create_table(
+        self,
+        identifier: Union[str, Identifier],
+        schema: Union[Schema, pa.Schema],
+        location: Optional[str] = None,
+        partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+        sort_order: sorting.SortOrder = sorting.UNSORTED_SORT_ORDER,
+        properties: Properties = EMPTY_DICT,
+    ) -> Table:
+        raise NotImplementedError
diff --git a/pyproject.toml b/pyproject.toml
index 2682e16173..9d79e93252 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,7 @@ fsspec = ">=2023.1.0,<2025.1.0"
 pyparsing = ">=3.1.0,<4.0.0"
 zstandard = ">=0.13.0,<1.0.0"
 tenacity = ">=8.2.3,<9.0.0"
+snowflake-connector-python = { version = ">=3.10.0", optional = true }
 pyarrow = { version = ">=9.0.0,<17.0.0", optional = true }
 pandas = { version = ">=1.0.0,<3.0.0", optional = true }
 duckdb = { version = ">=0.5.0,<1.0.0", optional = true }
@@ -177,10 +178,6 @@ ignore_missing_imports = true
 module = "tests.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "boto3"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "botocore.*"
 ignore_missing_imports = true
@@ -205,10 +202,6 @@ ignore_missing_imports = true
 module = "duckdb.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "ray.*"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "daft.*"
 ignore_missing_imports = true
@@ -237,18 +230,6 @@ ignore_missing_imports = true
 module = "sqlalchemy.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "Cython.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "setuptools.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "tenacity.*"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "pyarrow.*"
 ignore_missing_imports = true
@@ -333,10 +314,6 @@ ignore_missing_imports = true
 module = "tests.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "boto3"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "botocore.*"
 ignore_missing_imports = true
@@ -361,10 +338,6 @@ ignore_missing_imports = true
 module = "duckdb.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "ray.*"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "daft.*"
 ignore_missing_imports = true
@@ -393,18 +366,6 @@ ignore_missing_imports = true
 module = "sqlalchemy.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "Cython.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "setuptools.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "tenacity.*"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "pyarrow.*"
 ignore_missing_imports = true
@@ -489,10 +450,6 @@ ignore_missing_imports = true
 module = "tests.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "boto3"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "botocore.*"
 ignore_missing_imports = true
@@ -517,10 +474,6 @@ ignore_missing_imports = true
 module = "duckdb.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "ray.*"
-ignore_missing_imports = true
-
 [[tool.mypy.overrides]]
 module = "daft.*"
 ignore_missing_imports = true
@@ -549,18 +502,6 @@ ignore_missing_imports = true
 module = "sqlalchemy.*"
 ignore_missing_imports = true
 
-[[tool.mypy.overrides]]
-module = "Cython.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "setuptools.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "tenacity.*"
-ignore_missing_imports = true
-
 [tool.poetry.scripts]
 pyiceberg = "pyiceberg.cli.console:run"
 
@@ -588,6 +529,7 @@ zstandard = ["zstandard"]
 sql-postgres = ["sqlalchemy", "psycopg2-binary"]
 sql-sqlite = ["sqlalchemy"]
 gcsfs = ["gcsfs"]
+snowflake = ["snowflake-connector-python"]
 
 [tool.pytest.ini_options]
 markers = [
@@ -709,6 +651,10 @@ ignore_missing_imports = true
 module = "boto3"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "boto3.session"
+ignore_missing_imports = true
+
 [[tool.mypy.overrides]]
 module = "botocore.*"
 ignore_missing_imports = true
@@ -777,5 +723,9 @@ ignore_missing_imports = true
 module = "tenacity.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "snowflake.connector.*"
+ignore_missing_imports = true
+
 [tool.coverage.run]
 source = ['pyiceberg/']
diff --git a/tests/catalog/test_snowflake_catalog.py b/tests/catalog/test_snowflake_catalog.py
new file mode 100644
index 0000000000..b6ec0dbb01
--- /dev/null
+++ b/tests/catalog/test_snowflake_catalog.py
@@ -0,0 +1,229 @@
+import json
+from typing import Any, Generator, List
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from pyiceberg.catalog.snowflake_catalog import SnowflakeCatalog
+from pyiceberg.table.metadata import TableMetadataUtil
+
+
+class TestSnowflakeIdentifier:
+    def test_get_table_name(self) -> None:
+        sf_id = SnowflakeCatalog._SnowflakeIdentifier.table_from_string("db.schema.table")
+        assert sf_id.table_name == "db.schema.table"
+
+        sf_id = SnowflakeCatalog._SnowflakeIdentifier.table_from_string("schema.table")
+        assert sf_id.table_name == "schema.table"
+
+        sf_id = SnowflakeCatalog._SnowflakeIdentifier.table_from_string("table")
+        assert sf_id.table_name == "table"
+
+        with pytest.raises(ValueError):
+            SnowflakeCatalog._SnowflakeIdentifier.table_from_string("db.schema.table.extra")
+
+    def test_get_schema_name(self) -> None:
+        sf_id = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string("db.schema")
+        assert sf_id.schema_name == "db.schema"
+
+        sf_id = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string("schema")
+        assert sf_id.schema_name == "schema"
+
+        with pytest.raises(ValueError):
+            SnowflakeCatalog._SnowflakeIdentifier.schema_from_string("db.schema.extra")
+
+
+class MockSnowflakeCursor:
+    q = ""
+    qs: List[Any] = []
+
+    def __enter__(self) -> Any:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        pass
+
+    def rollback(self) -> None:
+        pass
+
+    def fetchall(self) -> Any:
+        if "SHOW ICEBERG TABLES" in self.q:
+            return [
+                {
+                    "database_name": "db",
+                    "schema_name": "schema",
+                    "table_name": "tbl_1",
+                },
+                {
+                    "database_name": "db",
+                    "schema_name": "schema",
+                    "table_name": "tbl_2",
+                },
+            ]
+
+        return []
+
+    def fetchone(self) -> Any:
+        if "SYSTEM$GET_ICEBERG_TABLE_INFORMATION" in self.q:
+            return {
+                "METADATA": json.dumps({
+                    "metadataLocation": "s3://bucket/path/to/metadata.json",
+                })
+            }
+
+    def execute(self, *args: Any, **kwargs: Any) -> Any:
+        self.q = args[0]
+        self.qs.append(args)
+
+
+class MockSnowflakeConnection:
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self._cursor = MockSnowflakeCursor()
+        self._cursor.qs = []
+
+    def cursor(self, *args: Any, **kwargs: Any) -> Any:
+        return self._cursor
+
+
+class MockCreds:
+    def get_frozen_credentials(self) -> Any:
+        creds = MagicMock()
+
+        creds.access_key = ""
+        creds.secret_key = ""
+        creds.token = ""
+
+        return creds
+
+
+class TestSnowflakeCatalog:
+    @pytest.fixture(scope="function")
+    def snowflake_catalog(self) -> Generator[SnowflakeCatalog, None, None]:
+        with patch(
+            "pyiceberg.serializers.FromInputFile.table_metadata",
+            return_value=TableMetadataUtil.parse_obj({
+                "format-version": 2,
+                "location": "s3://bucket/path/to/",
+                "last-column-id": 4,
+                "schemas": [{}],
+                "partition-specs": [{}],
+            }),
+        ):
+            with patch("pyiceberg.catalog.snowflake_catalog.Session.get_credentials", MockCreds):
+                with patch("pyiceberg.catalog.snowflake_catalog.SnowflakeConnection", MockSnowflakeConnection):
+                    yield SnowflakeCatalog(
+                        name="test",
+                        user="",
+                        account="",
+                    )
+
+    def test_load_table(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        tbl = snowflake_catalog.load_table("db.schema.table")
+
+        assert tbl is not None
+
+    def test_register_table(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        qs = snowflake_catalog.connection._cursor.qs
+
+        tbl = snowflake_catalog.register_table("db.schema.table", "s3://bucket/path/to/metadata.json")
+
+        assert len(qs) == 2
+
+        assert qs[0][0] == "CREATE ICEBERG TABLE (%s) METADATA_FILE_PATH = (%s)"
+        assert qs[0][1] == ("db.schema.table", "s3://bucket/path/to/metadata.json")
+
+        assert tbl is not None
+
+    def test_drop_table(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        snowflake_catalog.drop_table("db.schema.table")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 1
+
+        assert qs[0][0] == "DROP TABLE IF EXISTS (%s)"
+        assert qs[0][1] == ("db.schema.table",)
+
+    def test_rename_table(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        snowflake_catalog.rename_table("table", "schema.new_table")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 2
+
+        assert qs[0][0] == "ALTER TABLE (%s) RENAME TO (%s)"
+        assert qs[0][1] == ("table", "schema.new_table")
+
+    def test_create_namespace_schema_only(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        snowflake_catalog.create_namespace("schema")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 1
+
+        assert qs[0][0] == "CREATE SCHEMA IF NOT EXISTS (%s)"
+        assert qs[0][1] == ("schema",)
+
+    def test_create_namespace_with_db(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        snowflake_catalog.create_namespace("db.schema")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 2
+
+        assert qs[0][0] == "CREATE DATABASE IF NOT EXISTS (%s)"
+        assert qs[0][1] == ("db",)
+
+        assert qs[1][0] == "CREATE SCHEMA IF NOT EXISTS (%s)"
+        assert qs[1][1] == ("db.schema",)
+
+    def test_drop_namespace_schema_only(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        snowflake_catalog.drop_namespace("schema")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 1
+
+        assert qs[0][0] == "DROP SCHEMA IF EXISTS (%s)"
+        assert qs[0][1] == ("schema",)
+
+    def test_drop_namespace_with_db(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        snowflake_catalog.drop_namespace("db.schema")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 2
+
+        assert qs[0][0] == "DROP DATABASE IF EXISTS (%s)"
+        assert qs[0][1] == ("db",)
+
+        assert qs[1][0] == "DROP SCHEMA IF EXISTS (%s)"
+        assert qs[1][1] == ("db.schema",)
+
+    def test_list_tables_schema_only(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        tabs = snowflake_catalog.list_tables("schema")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 1
+
+        assert qs[0][0] == "SHOW ICEBERG TABLES IN SCHEMA (%s)"
+        assert qs[0][1] == ("schema",)
+
+        assert len(tabs) == 2
+        assert tabs[0] == ("db", "schema", "tbl_1")
+        assert tabs[1] == ("db", "schema", "tbl_2")
+
+    def test_list_tables_with_db(self, snowflake_catalog: SnowflakeCatalog) -> None:
+        tabs = snowflake_catalog.list_tables("db.schema")
+
+        qs = snowflake_catalog.connection._cursor.qs
+
+        assert len(qs) == 1
+
+        assert qs[0][0] == "SHOW ICEBERG TABLES IN DATABASE (%s)"
+        assert qs[0][1] == ("db",)
+
+        assert len(tabs) == 2
+        assert tabs[0] == ("db", "schema", "tbl_1")
+        assert tabs[1] == ("db", "schema", "tbl_2")
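
Reviewer note, not part of the patch: a minimal usage sketch for the new catalog, assuming the `snowflake` extra is installed. The catalog name and the user, account, and password values below are placeholders.

    from pyiceberg.catalog import load_catalog

    # "type": "snowflake" routes through the new load_snowflake() factory in
    # pyiceberg/catalog/__init__.py. SnowflakeCatalog requires the "user" and
    # "account" properties; "password", "authenticator", and "private_key" are
    # optional and forwarded to snowflake.connector.SnowflakeConnection.
    catalog = load_catalog(
        "my_catalog",  # placeholder name
        **{
            "type": "snowflake",
            "user": "my_user",  # placeholder
            "account": "my_account",  # placeholder
            "password": "my_password",  # placeholder
        },
    )

    # load_table() resolves the current metadata file through
    # SYSTEM$GET_ICEBERG_TABLE_INFORMATION and opens it as a StaticTable,
    # picking up S3 credentials from the default boto3 credential chain.
    tbl = catalog.load_table("db.schema.table")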
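
Reviewer note, not part of the patch: how the dot-separated identifiers are normalized by _SnowflakeIdentifier, mirroring the unit tests above.

    from pyiceberg.catalog.snowflake_catalog import SnowflakeCatalog

    # Three parts mean database.schema.table, two mean schema.table with no
    # database, and one is a bare table name; four or more raise ValueError.
    sf_id = SnowflakeCatalog._SnowflakeIdentifier.table_from_string("db.schema.table")
    assert (sf_id.database, sf_id.schema, sf_id.table) == ("db", "schema", "table")
    assert SnowflakeCatalog._SnowflakeIdentifier.table_from_string("schema.table").database is None

    # schema_from_string() shifts the interpretation: two parts are read as
    # database.schema rather than schema.table.
    ns = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string("db.schema")
    assert (ns.database, ns.schema, ns.table) == ("db", "schema", None)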
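
Reviewer note, not part of the patch: a sketch of the namespace and registration flow built from the statements this patch issues; identifiers and the metadata path are placeholders, and `catalog` is the instance from the first sketch above.

    # create_namespace() issues CREATE DATABASE/SCHEMA IF NOT EXISTS,
    # register_table() issues CREATE ICEBERG TABLE ... METADATA_FILE_PATH,
    # and list_tables() enumerates results via SHOW ICEBERG TABLES.
    catalog.create_namespace("db.schema")
    catalog.register_table("db.schema.table", "s3://bucket/path/to/metadata.json")
    print(catalog.list_tables("db.schema"))  # e.g. [("db", "schema", "table")]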