SQL to GCS with exclude columns #23695

Merged
12 commits merged on May 22, 2022
53 changes: 33 additions & 20 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -19,7 +19,7 @@
import abc
import json
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union, List

import pyarrow as pa
import pyarrow.parquet as pq
@@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param exclude_columns: list of columns to exclude from transmission
"""

template_fields: Sequence[str] = (
@@ -87,23 +88,24 @@ class BaseSQLToGCSOperator(BaseOperator):
ui_color = '#a0e08c'

def __init__(
self,
*,
sql: str,
bucket: str,
filename: str,
schema_filename: Optional[str] = None,
approx_max_file_size_bytes: int = 1900000000,
export_format: str = 'json',
field_delimiter: str = ',',
null_marker: Optional[str] = None,
gzip: bool = False,
schema: Optional[Union[str, list]] = None,
parameters: Optional[dict] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
self,
*,
sql: str,
bucket: str,
filename: str,
schema_filename: Optional[str] = None,
approx_max_file_size_bytes: int = 1900000000,
export_format: str = 'json',
field_delimiter: str = ',',
null_marker: Optional[str] = None,
gzip: bool = False,
schema: Optional[Union[str, list]] = None,
parameters: Optional[dict] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
exclude_columns: List[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.sql = sql
Expand All @@ -120,6 +122,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.exclude_columns = exclude_columns

def execute(self, context: 'Context'):
self.log.info("Executing query")
@@ -165,7 +168,13 @@ def _write_local_data_files(self, cursor):
names in GCS, and values are file handles to local files that
contain the data for the GCS objects.
"""
schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))

if self.exclude_columns is None:
schema = org_schema
else:
schema = [column for column in org_schema if column not in self.exclude_columns]

col_type_dict = self._get_col_type_dict()
file_no = 0

@@ -314,7 +323,11 @@ def _write_local_schema_file(self, cursor):
schema = self.schema
else:
self.log.info("Starts generating schema")
schema = [self.field_to_bigquery(field) for field in cursor.description]
if self.exclude_columns is None:
schema = [self.field_to_bigquery(field) for field in cursor.description]
else:
schema = [self.field_to_bigquery(field) for field in cursor.description if
field[0] not in self.exclude_columns]

if isinstance(schema, list):
schema = json.dumps(schema, sort_keys=True)
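Taken together, the change threads the new exclude_columns option from the constructor into both _write_local_data_files and _write_local_schema_file, so the named columns are dropped from the generated column list and schema file. A minimal usage sketch, not part of this diff: it assumes the parameter is simply forwarded through **kwargs by concrete subclasses such as MySQLToGCSOperator, and the bucket and object names below are placeholders.

import pendulum
from airflow import DAG
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

with DAG(
    dag_id="users_to_gcs",
    start_date=pendulum.datetime(2022, 5, 1, tz="UTC"),
    schedule_interval=None,
) as dag:
    MySQLToGCSOperator(
        task_id="export_users",
        sql="SELECT * FROM users",
        bucket="example-bucket",               # placeholder bucket name
        filename="users/export_{}.json",       # {} is replaced with the file number
        schema_filename="users/schema.json",
        export_format="json",
        exclude_columns=["password_hash"],     # columns excluded from transmission (per the new docstring)
    )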
Binary file removed tests/providers/google/cloud/transfers/temp-file
28 changes: 28 additions & 0 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -61,6 +61,11 @@

OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)

EXCLUDE_COLUMNS = ['column_c']
NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame([['convert_type_return_value'] * len(NEW_COLUMNS)] * 3,
columns=NEW_COLUMNS)


class DummySQLToGCSOperator(BaseSQLToGCSOperator):
def field_to_bigquery(self, field) -> Dict[str, str]:
@@ -287,3 +292,26 @@ def test__write_local_data_files_parquet(self):
file.flush()
df = pd.read_parquet(file.name)
assert df.equals(OUTPUT_DF)

def test__write_local_data_files_json_with_exclude_columns(self):
op = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
schema_filename=SCHEMA_FILE,
export_format="json",
gzip=False,
schema=SCHEMA,
gcp_conn_id='google_cloud_default',
exclude_columns=EXCLUDE_COLUMNS,
)
cursor = MagicMock()
cursor.__iter__.return_value = INPUT_DATA
cursor.description = CURSOR_DESCRIPTION

files = op._write_local_data_files(cursor)
file = next(files)['file_handle']
file.flush()
df = pd.read_json(file.name, orient='records', lines=True)
assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)
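
The new test above exercises only the data-file path. A companion check for the schema-file path could look like the following sketch, which is not part of this PR: it reuses the module's existing fixtures (DummySQLToGCSOperator, SQL, BUCKET, FILENAME, TASK_ID, SCHEMA_FILE, CURSOR_DESCRIPTION, EXCLUDE_COLUMNS), omits schema=SCHEMA so the generated schema goes through the new exclude branch, and assumes _write_local_schema_file returns a dict with a 'file_handle' entry, as _write_local_data_files does.

    def test__write_local_schema_file_with_exclude_columns(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="json",
            gcp_conn_id='google_cloud_default',
            exclude_columns=EXCLUDE_COLUMNS,
        )
        cursor = MagicMock()
        cursor.description = CURSOR_DESCRIPTION

        # No schema=... is passed, so the schema is generated from cursor.description
        # and excluded columns should never appear in the written schema file.
        schema_file = op._write_local_schema_file(cursor)
        schema_file['file_handle'].flush()

        with open(schema_file['file_handle'].name, 'rb') as file:
            content = file.read().decode('utf-8')
        assert all(column not in content for column in EXCLUDE_COLUMNS)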