Skip to content

Commit

Permalink
feat: change --selector argument type
Browse files Browse the repository at this point in the history
  • Loading branch information
yshalenyk committed Jul 3, 2024
1 parent 916da54 commit b1c9d8c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
19 changes: 13 additions & 6 deletions nightingale/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ def load_config(config_file):
@click.option("--datasource", type=str, help="Datasource connection string")
@click.option("--mapping-file", type=click_pathlib.Path(exists=True), help="Mapping file path")
@click.option("--ocid-prefix", type=str, help="OCID prefix")
@click.option("--selector", type=str, help="Selector")
@click.option("--selector", type=click_pathlib.Path(exists=True), help="Path to selector SQL script")
@click.option("--force-publish", is_flag=True, help="Force publish")
@click.option("--publisher", type=str, help="Publisher")
@click.option("--base-uri", type=str, help="Base URI")
@click.option("--version", type=str, help="Version")
@click.option("--publisher", type=str, help="Publisher name")
@click.option("--base-uri", type=str, help="Package base URI")
@click.option("--version", type=str, help="OCDS Version")
@click.option("--publisher-uid", type=str, help="Publisher UID")
@click.option("--publisher-scheme", type=str, help="Publisher scheme")
@click.option("--publisher-uri", type=str, help="Publisher URI")
@click.option("--extensions", type=str, multiple=True, help="Extensions")
@click.option("--extensions", type=str, multiple=True, help="List of extensions")
@click.option("--output-directory", type=click_pathlib.Path(exists=True), help="Output directory")
def run(
config_file,
Expand Down Expand Up @@ -106,10 +106,17 @@ def run(
if datasource:
config_data["datasource"] = {"connection": datasource}
if mapping_file or ocid_prefix or selector or force_publish:
selector_content = config_data["mapping"]["selector"]
if selector:
try:
with open(selector, "r") as f:
selector_content = f.read()
except (OSError, IOError) as e:
raise click.ClickException(f"Error reading selector file {selector}: {e}")
config_data["mapping"] = {
"file": mapping_file or config_data["mapping"]["file"],
"ocid_prefix": ocid_prefix or config_data["mapping"]["ocid_prefix"],
"selector": selector or config_data["mapping"]["selector"],
"selector": selector_content,
"force_publish": force_publish
if force_publish is not None
else config_data["mapping"].get("force_publish", False),
Expand Down
32 changes: 32 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.config_path = Path(self.temp_dir.name) / "test_config.toml"
self.invalid_config_path = Path(self.temp_dir.name) / "invalid_test_config.toml"
self.selector_data = "SELECT * FROM test_table;"
self.selector_path = Path(self.temp_dir.name) / "test_selector.sql"
with open(self.selector_path, "w") as f:
f.write(self.selector_data)
self.config_data = """
[datasource]
connection = "test_connection"
Expand Down Expand Up @@ -143,6 +147,34 @@ def test_invalid_toml_file(self):
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Error decoding TOML", result.output)

@patch("nightingale.cli.Config.from_file")
@patch("nightingale.cli.OCDSDataMapper")
@patch("nightingale.cli.DataLoader")
@patch("nightingale.cli.DataWriter")
def test_run_with_selector_file(self, mock_writer, mock_loader, mock_mapper, mock_config):
# Setup mocks
mock_config.return_value = MagicMock()
mock_mapper_instance = MagicMock()
mock_mapper.return_value = mock_mapper_instance

mock_loader_instance = MagicMock()
mock_loader.return_value = mock_loader_instance

mock_writer_instance = MagicMock()
mock_writer.return_value = mock_writer_instance

mock_mapper_instance.map.return_value = [{"dummy_data": "data"}]

result = self.runner.invoke(
run, ["--config", str(self.config_path), "--selector", str(self.selector_path), "--loglevel", "INFO"]
)

self.assertEqual(result.exit_code, 0)
mock_mapper.assert_called_once()
mock_loader.assert_called_once()
mock_writer.assert_called_once()
mock_writer_instance.write.assert_called_once_with([{"dummy_data": "data"}])


if __name__ == "__main__":
unittest.main()

0 comments on commit b1c9d8c

Please sign in to comment.