Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added VARBINARY Definition #37

Merged
merged 6 commits into from
Apr 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 81 additions & 13 deletions tap_mssql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import datetime

from base64 import b64encode
from decimal import Decimal
from uuid import uuid4
from typing import Any, Dict, Iterable, Optional
Expand All @@ -30,13 +31,12 @@
class mssqlConnector(SQLConnector):
"""Connects to the mssql SQL source."""

def __init__(self, config: dict | None = None, sqlalchemy_url: str | None = None) -> None:
"""Initialize the mssqlConnector.
Args:
config: The parent tap or target object's config.
sqlalchemy_url: Optional URL for the connection.
"""
def __init__(
self,
config: dict | None = None,
sqlalchemy_url: str | None = None
) -> None:
"""Class Default Init"""
# If pyodbc given set pyodbc.pooling to False
# This allows SQLA to manage to connection pool
if config['driver_type'] == 'pyodbc':
Expand Down Expand Up @@ -77,7 +77,9 @@ def get_sqlalchemy_url(cls, config: dict) -> str:
config_url = config_url.set(port=config['port'])

if 'sqlalchemy_url_query' in config:
config_url = config_url.update_query_dict(config['sqlalchemy_url_query'])
config_url = config_url.update_query_dict(
config['sqlalchemy_url_query']
)

return (config_url)

Expand Down Expand Up @@ -193,7 +195,9 @@ def hd_to_jsonschema_type(
):
sql_type_name = from_type.__name__
else:
raise ValueError("Expected `str` or a SQLAlchemy `TypeEngine` object or type.")
raise ValueError(
"Expected `str` or a SQLAlchemy `TypeEngine` object or type."
)

# Add in the length of the character column
if sql_type_name in ['CHAR', 'NCHAR', 'VARCHAR', 'NVARCHAR']:
Expand Down Expand Up @@ -223,6 +227,20 @@ def hd_to_jsonschema_type(
"contentMediaType": "application/xml",
}

if sql_type_name in ['BINARY', 'IMAGE', 'VARBINARY']:
maxLength: int = getattr(from_type, 'length')
if getattr(from_type, 'length'):
return {
"type": ["string"],
"contentEncoding": "base64",
"maxLength": maxLength
}
else:
return {
"type": ["string"],
"contentEncoding": "base64",
}

# This is a MSSQL only DataType
# SQLA does the conversion from 0,1
# to Python True, False
Expand Down Expand Up @@ -387,20 +405,70 @@ class mssqlStream(SQLStream):
connector_class = mssqlConnector
encoder_class = CustomJSONEncoder

def get_records(self, partition: Optional[dict]) -> Iterable[Dict[str, Any]]:
def post_process(
    self,
    row: dict,
    context: dict | None = None,  # noqa: ARG002
) -> dict | None:
    """Transform a raw record so binary columns match the stream schema.

    Every column whose schema property declares
    ``"contentEncoding": "base64"`` has its raw value replaced with the
    base64 text of that value. NULL values and columns that have no
    entry in the schema properties are left untouched.

    Args:
        row: Individual record in the stream. It is updated in place.
        context: Stream partition or context dictionary (unused here).

    Returns:
        The resulting record dict (the same object as ``row``). This
        implementation never returns ``None``; the annotation keeps the
        option open for filtering a record out.
    """
    # Use the name `record` inside the body so that the breaking rename
    # of the `row` parameter planned for SDK 1.0 only has to touch the
    # signature.
    record: dict = row

    # Stream properties dictionary from the schema; fall back to an empty
    # mapping so a schema without "properties" does not crash the loop.
    properties: dict = self.schema.get('properties') or {}

    for key, value in record.items():
        # NULL columns have nothing to encode; b64encode(None) would raise.
        if value is None:
            continue
        # Columns missing from the schema get an empty property mapping
        # instead of None, so the lookup below is always safe.
        property_schema: dict = properties.get(key) or {}
        if property_schema.get('contentEncoding') == 'base64':
            # b64encode returns bytes; decode so the value is a string,
            # as the schema declares. Replacing the value of an existing
            # key does not resize the dict, so this is safe mid-iteration.
            record[key] = b64encode(value).decode('ascii')

    return record

def get_records(
    self,
    partition: Optional[dict]
) -> Iterable[Dict[str, Any]]:
    """Return a generator of post-processed record dictionaries.

    Each record produced by the base class implementation is passed
    through ``post_process`` exactly once before being yielded.

    Args:
        partition: If provided, will read only from this data slice.

    Yields:
        One dict per record. Records for which ``post_process`` returns
        ``None`` are skipped.
    """
    # The loop follows the pattern of the SDK's rest.py `get_records`
    # so that records can be edited in `post_process` before they are
    # sent to the tap. The base generator is consumed only here; yielding
    # from it separately as well would emit every record twice.
    for record in super().get_records(partition):
        transformed_record = self.post_process(record, partition)
        if transformed_record is None:
            # Record filtered out during post_process()
            continue
        yield transformed_record

def get_batches(
self,
Expand Down