Added VARBINARY Definition (#37)
* Linting edits

* Added VARBINARY Definition

* changed from sql_type to from_type

* added base64encoding
BuzzCutNorman authored Apr 24, 2023
1 parent e27930e commit ff9d00d
Showing 1 changed file with 81 additions and 13 deletions.
94 changes: 81 additions & 13 deletions tap_mssql/client.py
@@ -8,6 +8,7 @@
import json
import datetime

from base64 import b64encode
from decimal import Decimal
from uuid import uuid4
from typing import Any, Dict, Iterable, Optional
@@ -30,13 +31,12 @@
class mssqlConnector(SQLConnector):
"""Connects to the mssql SQL source."""

def __init__(self, config: dict | None = None, sqlalchemy_url: str | None = None) -> None:
"""Initialize the mssqlConnector.
Args:
config: The parent tap or target object's config.
sqlalchemy_url: Optional URL for the connection.
"""
def __init__(
self,
config: dict | None = None,
sqlalchemy_url: str | None = None
) -> None:
"""Class Default Init"""
# If pyodbc is given, set pyodbc.pooling to False.
# This allows SQLA to manage the connection pool
if config['driver_type'] == 'pyodbc':
@@ -77,7 +77,9 @@ def get_sqlalchemy_url(cls, config: dict) -> str:
config_url = config_url.set(port=config['port'])

if 'sqlalchemy_url_query' in config:
config_url = config_url.update_query_dict(config['sqlalchemy_url_query'])
config_url = config_url.update_query_dict(
config['sqlalchemy_url_query']
)

return (config_url)
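As an aside, a minimal sketch (not part of the commit) of what update_query_dict does with the sqlalchemy_url_query options; the host, database, and driver values here are illustrative, and SQLAlchemy 1.4+ is assumed:

from sqlalchemy.engine import URL

# Build a base URL, then fold extra query options into it,
# mirroring what the connector does with config['sqlalchemy_url_query'].
url = URL.create("mssql+pyodbc", host="localhost", database="master")
url = url.update_query_dict({"driver": "ODBC Driver 18 for SQL Server"})
print(url)
# roughly: mssql+pyodbc://localhost/master?driver=ODBC+Driver+18+for+SQL+Server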

@@ -193,7 +195,9 @@ def hd_to_jsonschema_type(
):
sql_type_name = from_type.__name__
else:
raise ValueError("Expected `str` or a SQLAlchemy `TypeEngine` object or type.")
raise ValueError(
"Expected `str` or a SQLAlchemy `TypeEngine` object or type."
)

# Add in the length of the character types
if sql_type_name in ['CHAR', 'NCHAR', 'VARCHAR', 'NVARCHAR']:
@@ -223,6 +227,20 @@ def hd_to_jsonschema_type(
"contentMediaType": "application/xml",
}

if sql_type_name in ['BINARY', 'IMAGE', 'VARBINARY']:
maxLength: int = getattr(from_type, 'length')
if getattr(from_type, 'length'):
return {
"type": ["string"],
"contentEncoding": "base64",
"maxLength": maxLength
}
else:
return {
"type": ["string"],
"contentEncoding": "base64",
}

# This is a MSSQL only DataType
# SQLA does the conversion from 0,1
# to Python True, False
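For reference, a sketch (not part of the commit) of what the BINARY/IMAGE/VARBINARY branch above maps to, assuming the SQLAlchemy mssql dialect types:

from sqlalchemy.dialects.mssql import VARBINARY

# A bounded binary column carries its length into the JSON schema:
bounded = VARBINARY(100)    # bounded.length == 100
# hd_to_jsonschema_type(bounded) would return
#   {"type": ["string"], "contentEncoding": "base64", "maxLength": 100}

# VARBINARY(max) reports no length, so the else branch omits maxLength:
unbounded = VARBINARY()     # unbounded.length is None
# hd_to_jsonschema_type(unbounded) would return
#   {"type": ["string"], "contentEncoding": "base64"}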
@@ -387,20 +405,70 @@ class mssqlStream(SQLStream):
connector_class = mssqlConnector
encoder_class = CustomJSONEncoder

def get_records(self, partition: Optional[dict]) -> Iterable[Dict[str, Any]]:
def post_process(
self,
row: dict,
context: dict | None = None, # noqa: ARG002
) -> dict | None:
"""As needed, append or transform raw data to match expected structure.
Optional. This method gives developers an opportunity to "clean up" the results
prior to returning records to the downstream tap - for instance: cleaning,
renaming, or appending properties to the raw record result returned from the
API.
Developers may also return `None` from this method to filter out
invalid or not-applicable records from the stream.
Args:
row: Individual record in the stream.
context: Stream partition or context dictionary.
Returns:
The resulting record dict, or `None` if the record should be excluded.
"""
# We change the name to record so that when the breaking
# change from row to record is made in SDK 1.0, only two
# edits will be needed to accommodate the switch.
record: dict = row

# Get the stream properties dictionary from the schema
properties: dict = self.schema.get('properties')

for key, value in record.items():
# Get the Item/Column property
property_schema: dict = properties.get(key)
# Encode base64 binary fields in the record
if property_schema.get('contentEncoding') == 'base64':
record.update({key: b64encode(value)})

return record
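# Illustration (not part of the commit): for a schema property of
#   {"type": ["string"], "contentEncoding": "base64", "maxLength": 8}
# a raw row such as {"payload": b"\x00\xff"} comes back from
# post_process() as {"payload": b"AP8="}. Note that b64encode()
# returns bytes, so the stream's CustomJSONEncoder is presumably
# what turns that value into a string in the emitted JSON record.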

def get_records(
self,
partition: Optional[dict]
) -> Iterable[Dict[str, Any]]:
"""Return a generator of record-type dictionary objects.
Developers may optionally add custom logic before calling the default
implementation inherited from the base class.
Args:
partition: If provided, will read specifically from this data slice.
partition: If provided, will read only from this data slice.
Yields:
One dict per record.
"""

yield from super().get_records(partition)
# I took some of the get_records logic from rest.py and added
# it here so I can edit the records in the post_process method
# before they are sent to the tap.
for record in super().get_records(partition):
transformed_record = self.post_process(record)
if transformed_record is None:
# Record filtered out during post_process()
continue
yield transformed_record
# yield from super().get_records(partition)

def get_batches(
self,
