Skip to content

Commit

Permalink
tests: convert remote unit tests to pytest (#3604)
Browse files Browse the repository at this point in the history
- update tests/unit/remote tests to use pytest style tests
- use pytest dvc repo fixtures instead of improperly mocked repo's or
  None
  • Loading branch information
pmrowla authored Apr 7, 2020
1 parent ca60c8e commit dfe184f
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 174 deletions.
8 changes: 4 additions & 4 deletions tests/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ class S3Mocked(S3):

@classmethod
@contextmanager
def remote(cls):
def remote(cls, repo):
with mock_s3():
yield RemoteS3(None, {"url": cls.get_url()})
yield RemoteS3(repo, {"url": cls.get_url()})

@staticmethod
def put_objects(remote, objects):
Expand Down Expand Up @@ -127,8 +127,8 @@ def get_url():

@classmethod
@contextmanager
def remote(cls):
yield RemoteGS(None, {"url": cls.get_url()})
def remote(cls, repo):
yield RemoteGS(repo, {"url": cls.get_url()})

@staticmethod
def put_objects(remote, objects):
Expand Down
65 changes: 32 additions & 33 deletions tests/unit/remote/ssh/test_ssh.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import getpass
import os
import sys
from unittest import TestCase

import pytest
from mock import mock_open
Expand All @@ -12,31 +11,31 @@
from tests.remotes import SSHMocked


class TestRemoteSSH(TestCase):
def test_url(self):
user = "test"
host = "123.45.67.89"
port = 1234
path = "/path/to/dir"
def test_url(dvc):
user = "test"
host = "123.45.67.89"
port = 1234
path = "/path/to/dir"

# URL ssh://[user@]host.xz[:port]/path
url = "ssh://{}@{}:{}{}".format(user, host, port, path)
config = {"url": url}
# URL ssh://[user@]host.xz[:port]/path
url = "ssh://{}@{}:{}{}".format(user, host, port, path)
config = {"url": url}

remote = RemoteSSH(None, config)
self.assertEqual(remote.path_info, url)
remote = RemoteSSH(dvc, config)
assert remote.path_info == url

# SCP-like URL ssh://[user@]host.xz:/absolute/path
url = "ssh://{}@{}:{}".format(user, host, path)
config = {"url": url}
# SCP-like URL ssh://[user@]host.xz:/absolute/path
url = "ssh://{}@{}:{}".format(user, host, path)
config = {"url": url}

remote = RemoteSSH(None, config)
self.assertEqual(remote.path_info, url)
remote = RemoteSSH(dvc, config)
assert remote.path_info == url

def test_no_path(self):
config = {"url": "ssh://127.0.0.1"}
remote = RemoteSSH(None, config)
self.assertEqual(remote.path_info.path, "")

def test_no_path(dvc):
config = {"url": "ssh://127.0.0.1"}
remote = RemoteSSH(dvc, config)
assert remote.path_info.path == ""


mock_ssh_config = """
Expand Down Expand Up @@ -67,9 +66,9 @@ def test_no_path(self):
read_data=mock_ssh_config,
)
def test_ssh_host_override_from_config(
mock_file, mock_exists, config, expected_host
mock_file, mock_exists, dvc, config, expected_host
):
remote = RemoteSSH(None, config)
remote = RemoteSSH(dvc, config)

mock_exists.assert_called_with(RemoteSSH.ssh_config_filename())
mock_file.assert_called_with(RemoteSSH.ssh_config_filename())
Expand All @@ -96,8 +95,8 @@ def test_ssh_host_override_from_config(
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_user(mock_file, mock_exists, config, expected_user):
remote = RemoteSSH(None, config)
def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user):
remote = RemoteSSH(dvc, config)

mock_exists.assert_called_with(RemoteSSH.ssh_config_filename())
mock_file.assert_called_with(RemoteSSH.ssh_config_filename())
Expand All @@ -121,8 +120,8 @@ def test_ssh_user(mock_file, mock_exists, config, expected_user):
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_port(mock_file, mock_exists, config, expected_port):
remote = RemoteSSH(None, config)
def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port):
remote = RemoteSSH(dvc, config)

mock_exists.assert_called_with(RemoteSSH.ssh_config_filename())
mock_file.assert_called_with(RemoteSSH.ssh_config_filename())
Expand Down Expand Up @@ -156,8 +155,8 @@ def test_ssh_port(mock_file, mock_exists, config, expected_port):
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_keyfile(mock_file, mock_exists, config, expected_keyfile):
remote = RemoteSSH(None, config)
def test_ssh_keyfile(mock_file, mock_exists, dvc, config, expected_keyfile):
remote = RemoteSSH(dvc, config)

mock_exists.assert_called_with(RemoteSSH.ssh_config_filename())
mock_file.assert_called_with(RemoteSSH.ssh_config_filename())
Expand All @@ -178,15 +177,15 @@ def test_ssh_keyfile(mock_file, mock_exists, config, expected_keyfile):
new_callable=mock_open,
read_data=mock_ssh_config,
)
def test_ssh_gss_auth(mock_file, mock_exists, config, expected_gss_auth):
remote = RemoteSSH(None, config)
def test_ssh_gss_auth(mock_file, mock_exists, dvc, config, expected_gss_auth):
remote = RemoteSSH(dvc, config)

mock_exists.assert_called_with(RemoteSSH.ssh_config_filename())
mock_file.assert_called_with(RemoteSSH.ssh_config_filename())
assert remote.gss_auth == expected_gss_auth


def test_hardlink_optimization(tmp_dir, ssh_server):
def test_hardlink_optimization(dvc, tmp_dir, ssh_server):
port = ssh_server.test_creds["port"]
user = ssh_server.test_creds["username"]

Expand All @@ -196,7 +195,7 @@ def test_hardlink_optimization(tmp_dir, ssh_server):
"user": user,
"keyfile": ssh_server.test_creds["key_filename"],
}
remote = RemoteSSH(None, config)
remote = RemoteSSH(dvc, config)

from_info = remote.path_info / "empty"
to_info = remote.path_info / "link"
Expand Down
52 changes: 25 additions & 27 deletions tests/unit/remote/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
from unittest import TestCase

from dvc.remote.azure import RemoteAZURE


class TestRemoteAZURE(TestCase):
container_name = "container-name"
connection_string = (
"DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;"
"AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsu"
"Fq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;"
"BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
container_name = "container-name"
connection_string = (
"DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;"
"AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsu"
"Fq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;"
"BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
)


def test_init_compat(dvc):
url = (
"azure://ContainerName={container_name};{connection_string}"
).format(
container_name=container_name, connection_string=connection_string,
)
config = {"url": url}
remote = RemoteAZURE(dvc, config)
assert remote.path_info == "azure://" + container_name
assert remote.connection_string == connection_string

def test_init_compat(self):
url = (
"azure://ContainerName={container_name};{connection_string}"
).format(
container_name=self.container_name,
connection_string=self.connection_string,
)
config = {"url": url}
remote = RemoteAZURE(None, config)
self.assertEqual(remote.path_info, "azure://" + self.container_name)
self.assertEqual(remote.connection_string, self.connection_string)

def test_init(self):
prefix = "some/prefix"
url = "azure://{}/{}".format(self.container_name, prefix)
config = {"url": url, "connection_string": self.connection_string}
remote = RemoteAZURE(None, config)
self.assertEqual(remote.path_info, url)
self.assertEqual(remote.connection_string, self.connection_string)
def test_init(dvc):
prefix = "some/prefix"
url = "azure://{}/{}".format(container_name, prefix)
config = {"url": url, "connection_string": connection_string}
remote = RemoteAZURE(dvc, config)
assert remote.path_info == url
assert remote.connection_string == connection_string
57 changes: 25 additions & 32 deletions tests/unit/remote/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest import TestCase

import math
import mock
import pytest

from dvc.path_info import PathInfo
from dvc.remote.base import RemoteBASE
Expand All @@ -17,42 +16,36 @@ def __eq__(self, other):


CallableOrNone = _CallableOrNone()
REMOTE_CLS = RemoteBASE


class TestRemoteBASE(object):
REMOTE_CLS = RemoteBASE


class TestMissingDeps(TestCase, TestRemoteBASE):
def test(self):
requires = {"missing": "missing"}
with mock.patch.object(self.REMOTE_CLS, "REQUIRES", requires):
with self.assertRaises(RemoteMissingDepsError):
self.REMOTE_CLS(None, {})
def test_missing_deps(dvc):
requires = {"missing": "missing"}
with mock.patch.object(REMOTE_CLS, "REQUIRES", requires):
with pytest.raises(RemoteMissingDepsError):
REMOTE_CLS(dvc, {})


class TestCmdError(TestCase, TestRemoteBASE):
def test(self):
repo = None
config = {}
def test_cmd_error(dvc):
config = {}

cmd = "sed 'hello'"
ret = "1"
err = "sed: expression #1, char 2: extra characters after command"
cmd = "sed 'hello'"
ret = "1"
err = "sed: expression #1, char 2: extra characters after command"

with mock.patch.object(
self.REMOTE_CLS,
"remove",
side_effect=RemoteCmdError("base", cmd, ret, err),
):
with self.assertRaises(RemoteCmdError):
self.REMOTE_CLS(repo, config).remove("file")
with mock.patch.object(
REMOTE_CLS,
"remove",
side_effect=RemoteCmdError("base", cmd, ret, err),
):
with pytest.raises(RemoteCmdError):
REMOTE_CLS(dvc, config).remove("file")


@mock.patch.object(RemoteBASE, "_cache_checksums_traverse")
@mock.patch.object(RemoteBASE, "_cache_object_exists")
def test_cache_exists(object_exists, traverse):
remote = RemoteBASE(None, {})
def test_cache_exists(object_exists, traverse, dvc):
remote = RemoteBASE(dvc, {})

# remote does not support traverse
remote.CAN_TRAVERSE = False
Expand Down Expand Up @@ -110,8 +103,8 @@ def test_cache_exists(object_exists, traverse):
@mock.patch.object(
RemoteBASE, "path_to_checksum", side_effect=lambda x: x,
)
def test_cache_checksums_traverse(path_to_checksum, cache_checksums):
remote = RemoteBASE(None, {})
def test_cache_checksums_traverse(path_to_checksum, cache_checksums, dvc):
remote = RemoteBASE(dvc, {})
remote.path_info = PathInfo("foo")

# parallel traverse
Expand All @@ -135,8 +128,8 @@ def test_cache_checksums_traverse(path_to_checksum, cache_checksums):
)


def test_cache_checksums():
remote = RemoteBASE(None, {})
def test_cache_checksums(dvc):
remote = RemoteBASE(dvc, {})
remote.path_info = PathInfo("foo")

with mock.patch.object(
Expand Down
16 changes: 5 additions & 11 deletions tests/unit/remote/test_gdrive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import os

from dvc.config import Config
from dvc.remote.gdrive import (
RemoteGDrive,
GDriveAccessTokenRefreshError,
Expand All @@ -13,32 +12,27 @@
USER_CREDS_MISSED_KEY_ERROR = "{}"


class Repo(object):
tmp_dir = ""
config = Config()


class TestRemoteGDrive(object):
CONFIG = {
"url": "gdrive://root/data",
"gdrive_client_id": "client",
"gdrive_client_secret": "secret",
}

def test_init(self):
remote = RemoteGDrive(Repo(), self.CONFIG)
def test_init(self, dvc):
remote = RemoteGDrive(dvc, self.CONFIG)
assert str(remote.path_info) == self.CONFIG["url"]

def test_drive(self):
remote = RemoteGDrive(Repo(), self.CONFIG)
def test_drive(self, dvc):
remote = RemoteGDrive(dvc, self.CONFIG)
os.environ[
RemoteGDrive.GDRIVE_CREDENTIALS_DATA
] = USER_CREDS_TOKEN_REFRESH_ERROR
with pytest.raises(GDriveAccessTokenRefreshError):
remote._drive

os.environ[RemoteGDrive.GDRIVE_CREDENTIALS_DATA] = ""
remote = RemoteGDrive(Repo(), self.CONFIG)
remote = RemoteGDrive(dvc, self.CONFIG)
os.environ[
RemoteGDrive.GDRIVE_CREDENTIALS_DATA
] = USER_CREDS_MISSED_KEY_ERROR
Expand Down
Loading

0 comments on commit dfe184f

Please sign in to comment.