feat: Support ORC and CSV in redshift.copy_from_files function #2849

Merged
78 changes: 54 additions & 24 deletions awswrangler/redshift/_utils.py
@@ -6,14 +6,21 @@
import json
import logging
import uuid
from typing import TYPE_CHECKING, Literal

import boto3
import botocore
import pandas as pd

from awswrangler import _data_types, _sql_utils, _utils, exceptions, s3

redshift_connector = _utils.import_optional_dependency("redshift_connector")
if TYPE_CHECKING:
try:
import redshift_connector
except ImportError:
pass
else:
redshift_connector = _utils.import_optional_dependency("redshift_connector")

_logger: logging.Logger = logging.getLogger(__name__)

@@ -217,6 +224,7 @@ def _validate_parameters(

def _redshift_types_from_path(
path: str | list[str],
data_format: Literal["parquet", "orc"],
varchar_lengths_default: int,
varchar_lengths: dict[str, int] | None,
parquet_infer_sampling: float,
@@ -229,16 +237,27 @@ def _redshift_types_from_path(
"""Extract Redshift data types from a Pandas DataFrame."""
_varchar_lengths: dict[str, int] = {} if varchar_lengths is None else varchar_lengths
_logger.debug("Scanning parquet schemas in S3 path: %s", path)
athena_types, _ = s3.read_parquet_metadata(
path=path,
sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
dataset=False,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
if data_format == "orc":
athena_types, _ = s3.read_orc_metadata(
path=path,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
dataset=False,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
else:
athena_types, _ = s3.read_parquet_metadata(
path=path,
sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
dataset=False,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
_logger.debug("Parquet metadata types: %s", athena_types)
redshift_types: dict[str, str] = {}
for col_name, col_type in athena_types.items():
@@ -248,7 +267,7 @@
return redshift_types


def _create_table( # noqa: PLR0912,PLR0915
def _create_table( # noqa: PLR0912,PLR0913,PLR0915
df: pd.DataFrame | None,
path: str | list[str] | None,
con: "redshift_connector.Connection",
@@ -266,6 +285,8 @@ def _create_table( # noqa: PLR0912,PLR0915
primary_keys: list[str] | None,
varchar_lengths_default: int,
varchar_lengths: dict[str, int] | None,
data_format: Literal["parquet", "orc", "csv"] = "parquet",
redshift_column_types: dict[str, str] | None = None,
parquet_infer_sampling: float = 1.0,
path_suffix: str | None = None,
path_ignore_suffix: str | list[str] | None = None,
@@ -336,19 +357,28 @@ def _create_table( # noqa: PLR0912,PLR0915
path=path,
boto3_session=boto3_session,
)
redshift_types = _redshift_types_from_path(
path=path,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)

if data_format in ["parquet", "orc"]:
redshift_types = _redshift_types_from_path(
path=path,
data_format=data_format, # type: ignore[arg-type]
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
else:
if redshift_column_types is None:
raise ValueError(
"redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
)
redshift_types = redshift_column_types
else:
raise ValueError("df and path are None.You MUST pass at least one.")
raise ValueError("df and path are None. You MUST pass at least one.")
_validate_parameters(
redshift_types=redshift_types,
diststyle=diststyle,
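Aside: the branch added in _redshift_types_from_path above dispatches to the public metadata readers, so its effect can be previewed directly. A minimal sketch, assuming ORC files exist under a hypothetical S3 prefix (bucket name and printed types are placeholders):

import awswrangler as wr

# For data_format="orc" the helper now calls s3.read_orc_metadata
# (s3.read_parquet_metadata otherwise) and converts the returned
# Athena types into Redshift column types.
columns_types, _ = wr.s3.read_orc_metadata(
    path="s3://bucket/my_orc_files/",
    dataset=False,
)
print(columns_types)  # e.g. {"col1": "bigint", "col2": "string"}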
72 changes: 54 additions & 18 deletions awswrangler/redshift/_write.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from typing import Literal
from typing import TYPE_CHECKING, Literal, get_args

import boto3

@@ -15,21 +15,29 @@
from ._connect import _validate_connection
from ._utils import _create_table, _make_s3_auth_string, _upsert

redshift_connector = _utils.import_optional_dependency("redshift_connector")
if TYPE_CHECKING:
try:
import redshift_connector
except ImportError:
pass
else:
redshift_connector = _utils.import_optional_dependency("redshift_connector")

_logger: logging.Logger = logging.getLogger(__name__)

_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "delete"]
_ToSqlDistStyleLiteral = Literal["AUTO", "EVEN", "ALL", "KEY"]
_ToSqlSortStyleLiteral = Literal["COMPOUND", "INTERLEAVED"]
_CopyFromFilesDataFormatLiteral = Literal["parquet", "orc", "csv"]


def _copy(
cursor: "redshift_connector.Cursor", # type: ignore[name-defined]
cursor: "redshift_connector.Cursor",
path: str,
table: str,
serialize_to_json: bool,
data_format: _CopyFromFilesDataFormatLiteral = "parquet",
iam_role: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
@@ -45,6 +53,11 @@ def _copy(
else:
table_name = f'"{schema}"."{table}"'

if data_format not in ["parquet", "orc"] and serialize_to_json:
raise exceptions.InvalidArgumentCombination(
"You can only use SERIALIZETOJSON with data_format='parquet' or 'orc'."
)

auth_str: str = _make_s3_auth_string(
iam_role=iam_role,
aws_access_key_id=aws_access_key_id,
@@ -54,7 +67,9 @@ def _copy(
)
ser_json_str: str = " SERIALIZETOJSON" if serialize_to_json else ""
column_names_str: str = f"({','.join(column_names)})" if column_names else ""
sql = f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS PARQUET{ser_json_str}"
sql = (
f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS {data_format.upper()}{ser_json_str}"
)

if manifest:
sql += "\nMANIFEST"
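Aside: with the rebuilt f-string, the emitted statement follows the requested format rather than always FORMAT AS PARQUET. A hedged sketch of the ORC path from the public API (connection name, bucket, and role are placeholders):

import awswrangler as wr

with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
    # Internally reaches _copy(), which builds roughly:
    #   COPY "public"."my_table"
    #   FROM 's3://bucket/my_orc_files/' <auth clause>
    #   FORMAT AS ORC
    wr.redshift.copy_from_files(
        path="s3://bucket/my_orc_files/",
        con=con,
        table="my_table",
        schema="public",
        data_format="orc",
        iam_role="arn:aws:iam::XXX:role/XXX",
    )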
@@ -68,7 +83,7 @@
@apply_configs
def to_sql(
df: pd.DataFrame,
con: "redshift_connector.Connection", # type: ignore[name-defined]
con: "redshift_connector.Connection",
table: str,
schema: str,
mode: _ToSqlModeLiteral = "append",
@@ -240,13 +255,15 @@ def to_sql(
@_utils.check_optional_dependency(redshift_connector, "redshift_connector")
def copy_from_files( # noqa: PLR0913
path: str,
con: "redshift_connector.Connection", # type: ignore[name-defined]
con: "redshift_connector.Connection",
table: str,
schema: str,
iam_role: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
data_format: _CopyFromFilesDataFormatLiteral = "parquet",
redshift_column_types: dict[str, str] | None = None,
parquet_infer_sampling: float = 1.0,
mode: _ToSqlModeLiteral = "append",
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
@@ -270,16 +287,19 @@ def copy_from_files( # noqa: PLR0913
precombine_key: str | None = None,
column_names: list[str] | None = None,
) -> None:
"""Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).
"""Load files from S3 to a Table on Amazon Redshift (Through COPY command).

https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html

Note
----
If the table does not exist yet,
it will be automatically created for you
using the Parquet metadata to
using the Parquet/ORC/CSV metadata to
infer the columns data types.
If the data is in the CSV format,
the Redshift column types need to be
specified manually using ``redshift_column_types``.

Note
----
@@ -305,6 +325,15 @@ def copy_from_files( # noqa: PLR0913
The secret key for your AWS account.
aws_session_token : str, optional
The session key for your AWS account. This is only needed when you are using temporary credentials.
data_format: str, optional
Data format to be loaded.
Supported values are Parquet, ORC, and CSV.
Default is Parquet.
redshift_column_types: dict, optional
Dictionary with keys as column names and values as Redshift column types.
Only used when ``data_format`` is CSV.

e.g. ``{'col1': 'BIGINT', 'col2': 'VARCHAR(256)'}``
parquet_infer_sampling : float
Random sample ratio of files that will have the metadata inspected.
Must be `0.0 < sampling <= 1.0`.
@@ -382,25 +411,30 @@ def copy_from_files( # noqa: PLR0913
Examples
--------
>>> import awswrangler as wr
>>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
>>> wr.redshift.copy_from_files(
... path="s3://bucket/my_parquet_files/",
... con=con,
... table="my_table",
... schema="public",
... iam_role="arn:aws:iam::XXX:role/XXX"
... )
>>> con.close()
>>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
... wr.redshift.copy_from_files(
... path="s3://bucket/my_parquet_files/",
... con=con,
... table="my_table",
... schema="public",
... iam_role="arn:aws:iam::XXX:role/XXX"
... )

"""
_logger.debug("Copying objects from S3 path: %s", path)

data_format = data_format.lower() # type: ignore[assignment]
if data_format not in get_args(_CopyFromFilesDataFormatLiteral):
raise exceptions.InvalidArgumentValue(f"The specified data_format {data_format} is not supported.")

autocommit_temp: bool = con.autocommit
con.autocommit = False
try:
with con.cursor() as cursor:
created_table, created_schema = _create_table(
df=None,
path=path,
data_format=data_format,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
@@ -410,6 +444,7 @@ def copy_from_files( # noqa: PLR0913
schema=schema,
mode=mode,
overwrite_method=overwrite_method,
redshift_column_types=redshift_column_types,
diststyle=diststyle,
sortstyle=sortstyle,
distkey=distkey,
@@ -431,6 +466,7 @@ def copy_from_files( # noqa: PLR0913
table=created_table,
schema=created_schema,
iam_role=iam_role,
data_format=data_format,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
@@ -467,7 +503,7 @@ def copy_from_files( # noqa: PLR0913
def copy( # noqa: PLR0913
df: pd.DataFrame,
path: str,
con: "redshift_connector.Connection", # type: ignore[name-defined]
con: "redshift_connector.Connection",
table: str,
schema: str,
iam_role: str | None = None,
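Aside: to complement the Parquet example in the docstring above, a hedged sketch of a CSV load that exercises the two new parameters together (bucket, table, role, and column types are placeholders; CSV needs redshift_column_types because there is no embedded schema to infer from):

import awswrangler as wr

with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
    wr.redshift.copy_from_files(
        path="s3://bucket/my_csv_files/",
        con=con,
        table="my_table",
        schema="public",
        data_format="csv",
        redshift_column_types={"col1": "BIGINT", "col2": "VARCHAR(256)"},
        iam_role="arn:aws:iam::XXX:role/XXX",
    )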