Skip to content

Commit

Permalink
Use pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
kingsleyadam committed Sep 24, 2024
1 parent 5c77650 commit 4281d9a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ ignore = [
"Q",
"COM812",
"COM819",
"ISC001",
"PT" # Use Python built-in unittest vs pytest
"ISC001"
]

[tool.ruff.lint.flake8-pytest-style]
Expand Down
38 changes: 18 additions & 20 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""Unit test for Base device."""
"""Test class to test the Base device."""

import unittest
from unittest.mock import MagicMock

import pytest

from abbfreeathome.api import FreeAtHomeApi
from abbfreeathome.devices.base import Base
from abbfreeathome.exceptions import InvalidDeviceChannelPairingId


class TestBase(unittest.TestCase):
class TestBase:
"""The TestBase class for testing the Base device."""

def setUp(self):
@pytest.fixture(autouse=True)
def _setup(self):
"""Set up the test class."""
self.api = MagicMock(spec=FreeAtHomeApi)
self.device_id = "device123"
Expand Down Expand Up @@ -40,51 +42,47 @@ def setUp(self):

def test_device_id(self):
"""Test the device id."""
self.assertEqual(self.base.device_id, self.device_id)
assert self.base.device_id == self.device_id

def test_device_name(self):
"""Test the device name."""
self.assertEqual(self.base.device_name, self.device_name)
assert self.base.device_name == self.device_name

def test_channel_id(self):
"""Test the channel id."""
self.assertEqual(self.base.channel_id, self.channel_id)
assert self.base.channel_id == self.channel_id

def test_channel_name(self):
"""Test the channel name."""
self.assertEqual(self.base.channel_name, self.channel_name)
assert self.base.channel_name == self.channel_name

def test_get_input_by_pairing_id(self):
"""Test the get_input_by_paring_id function."""
input_id, value = self.base.get_input_by_pairing_id(1)
self.assertEqual(input_id, "input1")
self.assertEqual(value, "input_value1")
assert input_id == "input1"
assert value == "input_value1"

with self.assertRaises(InvalidDeviceChannelPairingId):
with pytest.raises(InvalidDeviceChannelPairingId):
self.base.get_input_by_pairing_id(99)

def test_get_output_by_pairing_id(self):
"""Test the get_output_by_pairing_id function."""
output_id, value = self.base.get_output_by_pairing_id(1)
self.assertEqual(output_id, "output1")
self.assertEqual(value, "output_value1")
assert output_id == "output1"
assert value == "output_value1"

with self.assertRaises(InvalidDeviceChannelPairingId):
with pytest.raises(InvalidDeviceChannelPairingId):
self.base.get_output_by_pairing_id(99)

def test_register_callback(self):
"""Test register a callback."""
callback = MagicMock()
self.base.register_callback(callback)
self.assertIn(callback, self.base._callbacks) # noqa: SLF001
assert callback in self.base._callbacks # noqa: SLF001

def test_remove_callback(self):
"""Test removing a callback."""
callback = MagicMock()
self.base.register_callback(callback)
self.base.remove_callback(callback)
self.assertNotIn(callback, self.base._callbacks) # noqa: SLF001


if __name__ == "__main__":
unittest.main()
assert callback not in self.base._callbacks # noqa: SLF001

0 comments on commit 4281d9a

Please sign in to comment.