Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to 'smbprotocol' library #17273

Merged
merged 7 commits into from
Aug 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 205 additions & 40 deletions airflow/providers/samba/hooks/samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,67 +16,232 @@
# specific language governing permissions and limitations
# under the License.

import os
import posixpath
from functools import wraps
from shutil import copyfileobj
from typing import Optional

from smbclient import SambaClient
import smbclient

from airflow.hooks.base import BaseHook


class SambaHook(BaseHook):
"""Allows for interaction with an samba server."""
"""Allows for interaction with a Samba server.

The hook should be used as a context manager in order to correctly
set up a session and disconnect open connections upon exit.

:param samba_conn_id: The connection id reference.
:type samba_conn_id: str
:param share:
An optional share name. If this is unset then the "schema" field of
the connection is used in its place.
:type share: str
"""

conn_name_attr = 'samba_conn_id'
default_conn_name = 'samba_default'
conn_type = 'samba'
hook_name = 'Samba'

def __init__(self, samba_conn_id: str = default_conn_name) -> None:
def __init__(self, samba_conn_id: str = default_conn_name, share: Optional[str] = None) -> None:
super().__init__()
self.conn = self.get_connection(samba_conn_id)
conn = self.get_connection(samba_conn_id)

if not conn.login:
self.log.info("Login not provided")

if not conn.password:
self.log.info("Password not provided")

self._host = conn.host
self._share = share or conn.schema
self._connection_cache = connection_cache = {}
self._conn_kwargs = {
"username": conn.login,
"password": conn.password,
"port": conn.port or 445,
"connection_cache": connection_cache,
malthe marked this conversation as resolved.
Show resolved Hide resolved
}

def __enter__(self):
# This immediately connects to the host (which can be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I read that correctly - it means that SambaHook MUST be used as a ContextManager in some cases, otherwise there might be some problems with initializing some parameters in the constructor? This is not a usual pattern we have in Airflow for hooks (though I think it's a nice pattern for Hooks), but I think some explanation is needed, at least in the docstring, explaining the difference between the two and when to use each.

Also - would you mind adding a CHANGELOG.txt entry? Don't put the version in yet (I will update it), but some backwards-compatibility notes are needed (how to migrate?)

Copy link
Contributor Author

@malthe malthe Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@potiuk where would those notes go? I don't see other changelog entries with such information. I have added a simple changelog entry for now and a note in the docstring about using it as a context manager.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can see some of those in other providers - for example, look at the Google provider: https://github.com/apache/airflow/blob/main/airflow/providers/google/CHANGELOG.rst - usually when there are breaking changes that require more than a one-line mention.

Even in Samba, we've added a warning when we removed "apply_default" (this was a global one and it actually did NOT really apply to Samba, but it went there as it was an "all operator" change).

https://github.com/apache/airflow/blob/main/airflow/providers/samba/CHANGELOG.rst

# perceived as a benefit), but also helps work around an issue:
#
# https://github.com/jborean93/smbprotocol/issues/109.
smbclient.register_session(self._host, **self._conn_kwargs)
return self

def __exit__(self, exc_type, exc_value, traceback):
    # Disconnect every connection held in this hook's connection cache, logging
    # each host first, then empty the cache. Returns None, so an exception
    # raised inside the ``with`` body is not suppressed.
    cache = self._connection_cache
    for host in cache:
        self.log.info("Disconnecting from %s", host)
        cache[host].disconnect()
    cache.clear()

def _join_path(self, path):
    """
    Build the full ``//host/share/path`` location for a share-relative path.

    Leading slashes are stripped from ``path`` before joining:
    ``posixpath.join`` discards all earlier components when a later one is
    absolute, which would otherwise drop the host and share from the result.

    :param path: Path inside the share, with or without a leading slash.
    :type path: str
    :return: The path prefixed with ``//<host>/<share>/``.
    :rtype: str
    """
    return f"//{posixpath.join(self._host, self._share, path.lstrip('/'))}"

@wraps(smbclient.link)
def link(self, src, dst, follow_symlinks=True):
    # Expand both share-relative paths to full //host/share/... paths, then
    # delegate to smbclient.link with this hook's connection settings.
    source = self._join_path(src)
    target = self._join_path(dst)
    return smbclient.link(source, target, follow_symlinks=follow_symlinks, **self._conn_kwargs)

@wraps(smbclient.listdir)
def listdir(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.listdir with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.listdir(full_path, **self._conn_kwargs)

@wraps(smbclient.lstat)
def lstat(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.lstat with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.lstat(full_path, **self._conn_kwargs)

@wraps(smbclient.makedirs)
def makedirs(self, path, exist_ok=False):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.makedirs with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.makedirs(full_path, exist_ok=exist_ok, **self._conn_kwargs)

@wraps(smbclient.mkdir)
def mkdir(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.mkdir with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.mkdir(full_path, **self._conn_kwargs)

@wraps(smbclient.open_file)
def open_file(
    self,
    path,
    mode="r",
    buffering=-1,
    encoding=None,
    errors=None,
    newline=None,
    share_access=None,
    desired_access=None,
    file_attributes=None,
    file_type="file",
):
    # Collect the file-opening options unchanged, expand the share-relative
    # path to a full //host/share/... path, and delegate to
    # smbclient.open_file with this hook's connection settings.
    open_options = dict(
        mode=mode,
        buffering=buffering,
        encoding=encoding,
        errors=errors,
        newline=newline,
        share_access=share_access,
        desired_access=desired_access,
        file_attributes=file_attributes,
        file_type=file_type,
    )
    return smbclient.open_file(self._join_path(path), **open_options, **self._conn_kwargs)

@wraps(smbclient.readlink)
def readlink(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.readlink with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.readlink(full_path, **self._conn_kwargs)

@wraps(smbclient.remove)
def remove(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.remove with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.remove(full_path, **self._conn_kwargs)

@wraps(smbclient.removedirs)
def removedirs(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.removedirs with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.removedirs(full_path, **self._conn_kwargs)

@wraps(smbclient.rename)
def rename(self, src, dst):
    # Expand both share-relative paths to full //host/share/... paths, then
    # delegate to smbclient.rename with this hook's connection settings.
    source = self._join_path(src)
    target = self._join_path(dst)
    return smbclient.rename(source, target, **self._conn_kwargs)

def get_conn(self) -> SambaClient:
"""
Return a samba client object.
@wraps(smbclient.replace)
def replace(self, src, dst):
    # Expand both share-relative paths to full //host/share/... paths, then
    # delegate to smbclient.replace with this hook's connection settings.
    source = self._join_path(src)
    target = self._join_path(dst)
    return smbclient.replace(source, target, **self._conn_kwargs)

You can provide optional parameters in the extra fields of
your connection.
@wraps(smbclient.rmdir)
def rmdir(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.rmdir with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.rmdir(full_path, **self._conn_kwargs)

Below is an inexhaustive list of these parameters:
@wraps(smbclient.scandir)
def scandir(self, path, search_pattern="*"):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.scandir with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.scandir(full_path, search_pattern=search_pattern, **self._conn_kwargs)

`logdir`
Base directory name for log/debug files.
@wraps(smbclient.stat)
def stat(self, path, follow_symlinks=True):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.stat with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.stat(full_path, follow_symlinks=follow_symlinks, **self._conn_kwargs)

@wraps(smbclient.stat_volume)
def stat_volume(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.stat_volume with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.stat_volume(full_path, **self._conn_kwargs)

@wraps(smbclient.symlink)
def symlink(self, src, dst, target_is_directory=False):
    # Expand both share-relative paths to full //host/share/... paths, then
    # delegate to smbclient.symlink with this hook's connection settings.
    source = self._join_path(src)
    target = self._join_path(dst)
    return smbclient.symlink(
        source,
        target,
        target_is_directory=target_is_directory,
        **self._conn_kwargs,
    )

`kerberos`
Try to authenticate with kerberos.
@wraps(smbclient.truncate)
def truncate(self, path, length):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.truncate with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.truncate(full_path, length, **self._conn_kwargs)

@wraps(smbclient.unlink)
def unlink(self, path):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.unlink with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.unlink(full_path, **self._conn_kwargs)

@wraps(smbclient.utime)
def utime(self, path, times=None, ns=None, follow_symlinks=True):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.utime with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.utime(
        full_path, times=times, ns=ns, follow_symlinks=follow_symlinks, **self._conn_kwargs
    )

`workgroup`
Set the SMB domain of the username.
@wraps(smbclient.walk)
def walk(self, path, topdown=True, onerror=None, follow_symlinks=False):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.walk with this hook's connection settings.
    root = self._join_path(path)
    return smbclient.walk(
        root, topdown=topdown, onerror=onerror, follow_symlinks=follow_symlinks, **self._conn_kwargs
    )

`netbios_name`
This option allows you to override the NetBIOS name that
Samba uses for itself.
@wraps(smbclient.getxattr)
def getxattr(self, path, attribute, follow_symlinks=True):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.getxattr with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.getxattr(full_path, attribute, follow_symlinks=follow_symlinks, **self._conn_kwargs)

@wraps(smbclient.listxattr)
def listxattr(self, path, follow_symlinks=True):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.listxattr with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.listxattr(full_path, follow_symlinks=follow_symlinks, **self._conn_kwargs)

@wraps(smbclient.removexattr)
def removexattr(self, path, attribute, follow_symlinks=True):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.removexattr with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.removexattr(full_path, attribute, follow_symlinks=follow_symlinks, **self._conn_kwargs)

For additional details, see `smbclient.SambaClient`.
"""
samba = SambaClient(
server=self.conn.host,
share=self.conn.schema,
username=self.conn.login,
ip=self.conn.host,
password=self.conn.password,
**self.conn.extra_dejson,
@wraps(smbclient.setxattr)
def setxattr(self, path, attribute, value, flags=0, follow_symlinks=True):
    # Expand the share-relative path to a full //host/share/... path, then
    # delegate to smbclient.setxattr with this hook's connection settings.
    full_path = self._join_path(path)
    return smbclient.setxattr(
        full_path, attribute, value, flags=flags, follow_symlinks=follow_symlinks, **self._conn_kwargs
    )
return samba

def push_from_local(self, destination_filepath: str, local_filepath: str):
    """
    Push local file to samba server.

    :param destination_filepath: Path of the file to write; it is passed to
        ``open_file`` and is therefore relative to the share.
    :type destination_filepath: str
    :param local_filepath: Path of the local file to upload.
    :type local_filepath: str
    """
    # Both ends are opened in binary mode: copyfileobj hands the bytes read
    # from the local file straight to the remote handle, which a text-mode
    # ("w") handle would reject.
    with open(local_filepath, "rb") as f, self.open_file(destination_filepath, mode="wb") as g:
        copyfileobj(f, g)
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-samba/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ PIP requirements
PIP package Version required
================== ==================
``apache-airflow`` ``>=2.1.0``
``pysmbclient`` ``>=0.1.3``
``smbprotocol`` ``>=1.5.0``
================== ==================

.. include:: ../../airflow/providers/samba/CHANGELOG.rst
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'tableauserverclient',
]
samba = [
'pysmbclient>=0.1.3',
'smbprotocol>=1.5.0',
]
segment = [
'analytics-python>=1.2.9',
Expand Down
Loading