Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ronpal committed Nov 1, 2024
1 parent 956f8e4 commit 9632256
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 3 deletions.
15 changes: 13 additions & 2 deletions cognite/client/data_classes/hosted_extractors/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,17 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
return output


@dataclass
class BasicAuthenticationWrite(AuthenticationWrite):
_type = "basic"
username: str
password: str

@classmethod
def _load_authentication(cls, resource: dict[str, Any]) -> Self:
return cls(username=resource["username"], password=resource["password"])


@dataclass
class RESTHeaderAuthenticationWrite(AuthenticationWrite):
_type = "header"
Expand Down Expand Up @@ -1000,7 +1011,7 @@ class RESTClientCredentialsAuthentication(Authentication):
client_secret: str
tokenUrl: str
scopes: str
defaultExpiresIn: str
defaultExpiresIn: str | None

@classmethod
def _load_authentication(cls, resource: dict[str, Any]) -> Self:
Expand All @@ -1009,7 +1020,7 @@ def _load_authentication(cls, resource: dict[str, Any]) -> Self:
client_secret=resource["clientSecret"],
tokenUrl=resource["tokenUrl"],
scopes=resource["scopes"],
defaultExpiresIn=resource["defaultExpiresIn"],
defaultExpiresIn=resource.get("defaultExpiresIn"),
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
import textwrap

from cognite.client.data_classes.hosted_extractors import RestSourceWrite, SourceWrite
import pytest

from cognite.client.data_classes.hosted_extractors.sources import (
_SOURCE_CLASS_BY_TYPE,
_SOURCE_WRITE_CLASS_BY_TYPE,
BasicAuthentication,
BasicAuthenticationWrite,
RESTClientCredentialsAuthentication,
RESTClientCredentialsAuthenticationWrite,
RESTHeaderAuthentication,
RESTHeaderAuthenticationWrite,
RESTQueryAuthentication,
RESTQueryAuthenticationWrite,
RestSourceWrite,
Source,
SourceWrite,
)


class TestSource:
Expand All @@ -20,3 +36,93 @@ def test_load_yaml_set_default_scheme(self) -> None:
loaded = SourceWrite.load(raw_yaml)
assert isinstance(loaded, RestSourceWrite)
assert loaded.scheme == "https"


@pytest.fixture(
params=[
(
{
"source": "mqtt3",
"externalId": "mqtt-source",
"host": "mqtt-broker",
"port": 1883,
"authentication": {"type": "basic", "username": "user", "password": "pass"},
},
BasicAuthentication,
BasicAuthenticationWrite,
),
(
{
"source": "rest",
"externalId": "rest-source",
"host": "rest-host",
"port": 443,
"authentication": {"type": "basic", "username": "user", "password": "pass"},
},
BasicAuthentication,
BasicAuthenticationWrite,
),
(
{
"source": "rest",
"externalId": "rest-source",
"host": "rest-host",
"port": 443,
"authentication": {"type": "header", "key": "key", "value": "value"},
},
RESTHeaderAuthentication,
RESTHeaderAuthenticationWrite,
),
(
{
"source": "rest",
"externalId": "rest-source",
"host": "rest-host",
"port": 443,
"authentication": {"type": "query", "key": "key", "value": "value"},
},
RESTQueryAuthentication,
RESTQueryAuthenticationWrite,
),
(
{
"source": "rest",
"externalId": "rest-source",
"host": "rest-host",
"port": 443,
"authentication": {
"type": "clientCredentials",
"clientId": "client-id",
"clientSecret": "client-secret",
"tokenUrl": "https://token.url",
"scopes": ["scope1", "scope2"],
},
},
RESTClientCredentialsAuthentication,
RESTClientCredentialsAuthenticationWrite,
),
]
)
def sample_sources(request):
return request.param


def test_auth_loaders_auth_cls(sample_sources):
resource, expected_auth_cls, expected_auth_write_cls = sample_sources

source_cls = _SOURCE_CLASS_BY_TYPE.get(resource["source"])
resource["createdTime"] = "1970-01-01T00:00:00Z"
resource["lastUpdatedTime"] = "1970-01-01T00:00:01Z"
if resource.get("port", "") == 443:
resource["scheme"] = "https"

obj: Source = source_cls._load(resource=resource)
assert isinstance(obj.authentication, expected_auth_cls)


def test_auth_loaders(sample_sources) -> None:
resource, expected_auth_cls, expected_auth_write_cls = sample_sources

source_write_cls = _SOURCE_WRITE_CLASS_BY_TYPE.get(resource["source"])
obj: SourceWrite = source_write_cls._load(resource=resource)
assert isinstance(obj.authentication, expected_auth_write_cls)

0 comments on commit 9632256

Please sign in to comment.