
Sql to gcs with exclude columns #23695

Merged · 12 commits · May 22, 2022
16 changes: 14 additions & 2 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param exclude_columns: set of columns to exclude from transmission
    """

    template_fields: Sequence[str] = (
@@ -103,9 +104,13 @@ def __init__(
        gcp_conn_id: str = 'google_cloud_default',
        delegate_to: Optional[str] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        exclude_columns=None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if exclude_columns is None:
            exclude_columns = set()

        self.sql = sql
        self.bucket = bucket
        self.filename = filename
@@ -120,6 +125,7 @@ def __init__(
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.impersonation_chain = impersonation_chain
        self.exclude_columns = exclude_columns

    def execute(self, context: 'Context'):
        self.log.info("Executing query")
@@ -165,7 +171,9 @@ def _write_local_data_files(self, cursor):
        names in GCS, and values are file handles to local files that
        contain the data for the GCS objects.
        """
-        schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
+        org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
+        schema = [column for column in org_schema if column not in self.exclude_columns]

Member commented on lines +174 to +176:
Consider casting exclude_columns to a set in the constructor. Then you could do:

Suggested change:
-        org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
-        schema = [column for column in org_schema if column not in self.exclude_columns]
+        org_schema = set(schema_tuple[0] for schema_tuple in cursor.description)
+        schema = org_schema - self.exclude_columns

WDYT?

Contributor (author) replied:

If schema is changed to a set, the column order changes; it seems better to use a list to preserve the column order.

However, it does seem good to receive the input as a set, to prevent duplicate entries in the exclude columns:

        delegate_to: Optional[str] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        exclude_columns=None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if exclude_columns is None:
            exclude_columns = set()

This uses exclude_columns as a set to prevent duplication.
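A standalone sketch of the ordering trade-off discussed above (illustrative column names, not part of the diff):

cursor_description = [('id',), ('name',), ('created_at',)]
exclude_columns = {'name'}

# Set difference removes the excluded column but loses the cursor's column order.
as_set = {t[0] for t in cursor_description} - exclude_columns

# A list comprehension with a set membership test keeps the cursor's order.
as_list = [t[0] for t in cursor_description if t[0] not in exclude_columns]

print(as_set)   # e.g. {'created_at', 'id'} -- iteration order not guaranteed
print(as_list)  # ['id', 'created_at'] -- order preserved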

Member replied:

Sounds good to me; consider exclude_columns = exclude_columns or set()
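As a standalone illustration of that suggestion (hypothetical minimal class, not the PR's code):

class ExampleOperator:  # hypothetical stand-in for BaseSQLToGCSOperator
    def __init__(self, exclude_columns=None):
        # `exclude_columns or set()` maps None (or any falsy value) to an empty set,
        # so the rest of the class can rely on set semantics for the default case.
        self.exclude_columns = exclude_columns or set()

assert ExampleOperator().exclude_columns == set()
assert ExampleOperator({'secret'}).exclude_columns == {'secret'}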

Contributor (author) replied:

@turbaszek I changed the input type; please check.

        col_type_dict = self._get_col_type_dict()
        file_no = 0

@@ -314,7 +322,11 @@ def _write_local_schema_file(self, cursor):
            schema = self.schema
        else:
            self.log.info("Starts generating schema")
-            schema = [self.field_to_bigquery(field) for field in cursor.description]
+            schema = [
+                self.field_to_bigquery(field)
+                for field in cursor.description
+                if field[0] not in self.exclude_columns
+            ]

        if isinstance(schema, list):
            schema = json.dumps(schema, sort_keys=True)
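For context, a usage sketch of the new parameter with one of the concrete subclasses (bucket and file names are illustrative):

from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

upload = MySQLToGCSOperator(
    task_id='users_to_gcs',
    sql='SELECT * FROM users',
    bucket='example-bucket',
    filename='users/export_{}.json',
    schema_filename='users/schema.json',
    export_format='json',
    exclude_columns={'password_hash'},  # dropped from both the data and the schema file
)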
Binary file removed: tests/providers/google/cloud/transfers/temp-file
29 changes: 29 additions & 0 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -61,6 +61,12 @@

OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)

EXCLUDE_COLUMNS = {'column_c'}  # a set literal; set('column_c') would yield the set of its characters
NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
    [['convert_type_return_value'] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
)


class DummySQLToGCSOperator(BaseSQLToGCSOperator):
    def field_to_bigquery(self, field) -> Dict[str, str]:
@@ -287,3 +293,26 @@ def test__write_local_data_files_parquet(self):
        file.flush()
        df = pd.read_parquet(file.name)
        assert df.equals(OUTPUT_DF)

    def test__write_local_data_files_json_with_exclude_columns(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="json",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
            exclude_columns=EXCLUDE_COLUMNS,
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = next(files)['file_handle']
        file.flush()
        df = pd.read_json(file.name, orient='records', lines=True)
        assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)
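To run just the new test locally (standard pytest keyword selection; path per the repo layout above):

pytest tests/providers/google/cloud/transfers/test_sql_to_gcs.py -k exclude_columns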