diff --git a/RELEASE.md b/RELEASE.md index 1db7d8c920..c287a896a8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -5,6 +5,11 @@ Please follow the established format: - Use present tense (e.g. 'Add new feature') - Include the ID number for the related PR (or PRs) in parentheses --> +## Major features and improvements + +- Add support for displaying dataset statistics in the metadata panel. (#1472) + +# Release 6.3.5 ## Bug fixes and other changes diff --git a/cypress/tests/ui/flowchart/flowchart.cy.js b/cypress/tests/ui/flowchart/flowchart.cy.js index c498a67ae9..bd0e0e3de4 100644 --- a/cypress/tests/ui/flowchart/flowchart.cy.js +++ b/cypress/tests/ui/flowchart/flowchart.cy.js @@ -192,4 +192,28 @@ describe('Flowchart DAG', () => { .should('exist') .and('have.text', `Oops, there's nothing to see here`); }); + + it('verifies that users can open and see the dataset statistics in the metadata panel for datasets. #TC-51', () => { + const dataNodeText = 'Companies'; + + // Assert before action + cy.get('[data-label="Dataset statistics:"]').should('not.exist'); + + // Action + cy.get('.pipeline-node > .pipeline-node__text') + .contains(dataNodeText) + .click({ force: true }); + + // Assert after action + cy.get('[data-label="Dataset statistics:"]').should('exist'); + cy.get('[data-test=stats-value-rows]') + .invoke('text') + .should((rowsValue) => expect(rowsValue).to.be.eq('77,096')); + cy.get('[data-test=stats-value-columns]') + .invoke('text') + .should((colsValue) => expect(parseInt(colsValue)).to.be.eq(5)); + cy.get('[data-test=stats-value-file_size]') + .invoke('text') + .should((fileSizeValue) => expect(fileSizeValue).to.be.eq('1.8MB')); + }); }); diff --git a/demo-project/stats.json b/demo-project/stats.json new file mode 100644 index 0000000000..34723c9a86 --- /dev/null +++ b/demo-project/stats.json @@ -0,0 +1,39 @@ +{ + "companies": { "rows": 77096, "columns": 5, "file_size": 1810602 }, + "ingestion.int_typed_companies": { + "rows": 77096, + "columns": 5, + 
"file_size": 550616 + }, + "reviews": { "rows": 77096, "columns": 10, "file_size": 2937144 }, + "ingestion.int_typed_reviews": { + "rows": 55790, + "columns": 11, + "file_size": 1335600 + }, + "shuttles": { "rows": 77096, "columns": 13, "file_size": 4195290 }, + "ingestion.int_typed_shuttles": { + "rows": 77096, + "columns": 13, + "file_size": 1235685 + }, + "ingestion.prm_agg_companies": { "rows": 50098, "columns": 5 }, + "prm_shuttle_company_reviews": { + "rows": 29768, + "columns": 27, + "file_size": 1020356 + }, + "prm_spine_table": { "rows": 29768, "columns": 3, "file_size": 655994 }, + "feature_engineering.feat_derived_features": { "rows": 29768, "columns": 3 }, + "feature_importance_output": { "rows": 15, "columns": 2, "file_size": 460 }, + "feature_engineering.feat_static_features": { "rows": 29768, "columns": 12 }, + "ingestion.prm_spine_table_clone": { "rows": 29768, "columns": 3 }, + "reporting.cancellation_policy_breakdown": { + "rows": 21, + "columns": 3, + "file_size": 8744 + }, + "model_input_table": { "rows": 29768, "columns": 12, "file_size": 787351 }, + "X_train": { "rows": 23814, "columns": 11 }, + "X_test": { "rows": 5954, "columns": 11 } +} diff --git a/package/kedro_viz/api/rest/responses.py b/package/kedro_viz/api/rest/responses.py index 22ece3711f..1032ad3436 100644 --- a/package/kedro_viz/api/rest/responses.py +++ b/package/kedro_viz/api/rest/responses.py @@ -114,6 +114,7 @@ class DataNodeMetadataAPIResponse(BaseAPIResponse): tracking_data: Optional[Dict] run_command: Optional[str] preview: Optional[Dict] + stats: Optional[Dict] class Config: schema_extra = { @@ -130,6 +131,7 @@ class TranscodedDataNodeMetadataAPIReponse(BaseAPIResponse): original_type: str transcoded_types: List[str] run_command: Optional[str] + stats: Optional[Dict] class ParametersNodeMetadataAPIResponse(BaseAPIResponse): diff --git a/package/kedro_viz/api/rest/router.py b/package/kedro_viz/api/rest/router.py index 7a89383065..1b48be2f9d 100644 --- 
a/package/kedro_viz/api/rest/router.py +++ b/package/kedro_viz/api/rest/router.py @@ -49,10 +49,12 @@ async def get_single_node_metadata(node_id: str): return TaskNodeMetadata(node) if isinstance(node, DataNode): - return DataNodeMetadata(node) + dataset_stats = data_access_manager.get_stats_for_data_node(node) + return DataNodeMetadata(node, dataset_stats) if isinstance(node, TranscodedDataNode): - return TranscodedDataNodeMetadata(node) + dataset_stats = data_access_manager.get_stats_for_data_node(node) + return TranscodedDataNodeMetadata(node, dataset_stats) return ParametersNodeMetadata(node) diff --git a/package/kedro_viz/data_access/managers.py b/package/kedro_viz/data_access/managers.py index cae46e7e8e..cae19cc80c 100644 --- a/package/kedro_viz/data_access/managers.py +++ b/package/kedro_viz/data_access/managers.py @@ -62,6 +62,7 @@ def __init__(self): ) self.runs = RunsRepository() self.tracking_datasets = TrackingDatasetsRepository() + self.dataset_stats = {} def set_db_session(self, db_session_class: sessionmaker): """Set db session on repositories that need it.""" @@ -91,6 +92,28 @@ def add_pipelines(self, pipelines: Dict[str, KedroPipeline]): # Add the registered pipeline and its components to their repositories self.add_pipeline(registered_pipeline_id, pipeline) + def add_dataset_stats(self, stats_dict: Dict): + """Add dataset statistics (eg. rows, columns, file_size) as a dictionary. 
+ This will help in showing the relevant stats in the metadata panel + + Args: + stats_dict: A dictionary object loaded from stats.json file in the kedro project + """ + + self.dataset_stats = stats_dict + + def get_stats_for_data_node( + self, data_node: Union[DataNode, TranscodedDataNode] + ) -> Dict: + """Returns the dataset statistics for the data node if found else returns an + empty dictionary + + Args: + The data node for which we need the statistics + """ + + return self.dataset_stats.get(data_node.name, {}) + def add_pipeline(self, registered_pipeline_id: str, pipeline: KedroPipeline): """Iterate through all the nodes and datasets in a "registered" pipeline and add them to relevant repositories. Take care of extracting other relevant information diff --git a/package/kedro_viz/integrations/kedro/data_loader.py b/package/kedro_viz/integrations/kedro/data_loader.py index 69c9f74888..4b78b44d95 100644 --- a/package/kedro_viz/integrations/kedro/data_loader.py +++ b/package/kedro_viz/integrations/kedro/data_loader.py @@ -6,6 +6,8 @@ # pylint: disable=missing-function-docstring, no-else-return import base64 +import json +import logging from pathlib import Path from typing import Any, Dict, Optional, Tuple @@ -14,23 +16,25 @@ try: from kedro_datasets import ( # isort:skip - json, + json as json_dataset, matplotlib, plotly, tracking, ) except ImportError: from kedro.extras.datasets import ( # Safe since ImportErrors are suppressed within kedro. - json, + json as json_dataset, matplotlib, plotly, tracking, ) + from kedro.io import DataCatalog from kedro.io.core import get_filepath_str from kedro.pipeline import Pipeline from semver import VersionInfo +logger = logging.getLogger(__name__) KEDRO_VERSION = VersionInfo.parse(__version__) @@ -54,11 +58,37 @@ def _bootstrap(project_path: Path): return +def get_dataset_stats(project_path: Path) -> Dict: + """Return the stats saved at stats.json as a dictionary if found. 
+ If not, return an empty dictionary + + Args: + project_path: the path where the Kedro project is located. + """ + try: + stats_file_path = project_path / "stats.json" + + if not stats_file_path.exists(): + return {} + + with open(stats_file_path, encoding="utf8") as stats_file: + stats = json.load(stats_file) + return stats + + except Exception as exc: # pylint: disable=broad-exception-caught + logger.warning( + "Unable to get dataset statistics from project path %s : %s", + project_path, + exc, + ) + return {} + + def load_data( project_path: Path, env: Optional[str] = None, extra_params: Optional[Dict[str, Any]] = None, -) -> Tuple[DataCatalog, Dict[str, Pipeline], BaseSessionStore]: +) -> Tuple[DataCatalog, Dict[str, Pipeline], BaseSessionStore, Dict]: """Load data from a Kedro project. Args: project_path: the path whether the Kedro project is located. @@ -91,7 +121,9 @@ def load_data( # in case user doesn't have an active session down the line when it's first accessed. # Useful for users who have `get_current_session` in their `register_pipelines()`. 
pipelines_dict = dict(pipelines) - return catalog, pipelines_dict, session_store + stats_dict = get_dataset_stats(project_path) + + return catalog, pipelines_dict, session_store, stats_dict elif KEDRO_VERSION.match(">=0.17.1"): from kedro.framework.session import KedroSession @@ -103,8 +135,9 @@ def load_data( ) as session: context = session.load_context() session_store = session._store + stats_dict = get_dataset_stats(project_path) - return context.catalog, context.pipelines, session_store + return context.catalog, context.pipelines, session_store, stats_dict else: # Since Viz is only compatible with kedro>=0.17.0, this just matches 0.17.0 from kedro.framework.session import KedroSession @@ -120,8 +153,9 @@ def load_data( ) as session: context = session.load_context() session_store = session._store + stats_dict = get_dataset_stats(project_path) - return context.catalog, context.pipelines, session_store + return context.catalog, context.pipelines, session_store, stats_dict # The dataset type is available as an attribute if and only if the import from kedro @@ -140,13 +174,13 @@ def matplotlib_writer_load(dataset: matplotlib.MatplotlibWriter) -> str: matplotlib.MatplotlibWriter._load = matplotlib_writer_load if hasattr(plotly, "JSONDataSet"): - plotly.JSONDataSet._load = json.JSONDataSet._load + plotly.JSONDataSet._load = json_dataset.JSONDataSet._load if hasattr(plotly, "PlotlyDataSet"): - plotly.PlotlyDataSet._load = json.JSONDataSet._load + plotly.PlotlyDataSet._load = json_dataset.JSONDataSet._load if hasattr(tracking, "JSONDataSet"): - tracking.JSONDataSet._load = json.JSONDataSet._load + tracking.JSONDataSet._load = json_dataset.JSONDataSet._load if hasattr(tracking, "MetricsDataSet"): - tracking.MetricsDataSet._load = json.JSONDataSet._load + tracking.MetricsDataSet._load = json_dataset.JSONDataSet._load diff --git a/package/kedro_viz/integrations/kedro/hooks.py b/package/kedro_viz/integrations/kedro/hooks.py new file mode 100644 index 0000000000..d51291c03b 
--- /dev/null +++ b/package/kedro_viz/integrations/kedro/hooks.py @@ -0,0 +1,174 @@ +# pylint: disable=broad-exception-caught, protected-access +"""`kedro_viz.integrations.kedro.hooks` defines hooks to add additional +functionalities for a kedro run.""" + +import json +import logging +from collections import defaultdict +from typing import Any, Union + +from kedro.framework.hooks import hook_impl +from kedro.io import DataCatalog +from kedro.io.core import get_filepath_str +from kedro.pipeline.pipeline import TRANSCODING_SEPARATOR, _strip_transcoding + +logger = logging.getLogger(__name__) + + +class DatasetStatsHook: + """Class to collect dataset statistics during a kedro run + and save it to a JSON file. The class currently supports + (pd.DataFrame) dataset instances""" + + def __init__(self): + self._stats = defaultdict(dict) + + @hook_impl + def after_catalog_created(self, catalog: DataCatalog): + """Hooks to be invoked after a data catalog is created. + + Args: + catalog: The catalog that was created. + """ + + self.datasets = catalog._data_sets + + @hook_impl + def after_dataset_loaded(self, dataset_name: str, data: Any): + """Hook to be invoked after a dataset is loaded from the catalog. + Once the dataset is loaded, extract the required dataset statistics. + The hook currently supports (pd.DataFrame) dataset instances + + Args: + dataset_name: name of the dataset that was loaded from the catalog. + data: the actual data that was loaded from the catalog. + """ + + self.create_dataset_stats(dataset_name, data) + + @hook_impl + def after_dataset_saved(self, dataset_name: str, data: Any): + """Hook to be invoked after a dataset is saved to the catalog. + Once the dataset is saved, extract the required dataset statistics. + The hook currently supports (pd.DataFrame) dataset instances + + Args: + dataset_name: name of the dataset that was saved to the catalog. + data: the actual data that was saved to the catalog. 
+ """ + + self.create_dataset_stats(dataset_name, data) + + @hook_impl + def after_pipeline_run(self): + """Hook to be invoked after a pipeline runs. + Once the pipeline run completes, write the dataset + statistics to stats.json file + + """ + try: + with open("stats.json", "w", encoding="utf8") as file: + sorted_stats_data = { + dataset_name: self.format_stats(stats) + for dataset_name, stats in self._stats.items() + } + json.dump(sorted_stats_data, file) + + except Exception as exc: # pragma: no cover + logger.warning( + "Unable to write dataset statistics for the pipeline: %s", exc + ) + + def create_dataset_stats(self, dataset_name: str, data: Any): + """Helper method to create dataset statistics. + Currently supports (pd.DataFrame) dataset instances. + + Args: + dataset_name: The dataset name for which we need the statistics + data: Actual data that is loaded/saved to the catalog + + """ + try: + import pandas as pd # pylint: disable=import-outside-toplevel + + stats_dataset_name = self.get_stats_dataset_name(dataset_name) + + if isinstance(data, pd.DataFrame): + self._stats[stats_dataset_name]["rows"] = int(data.shape[0]) + self._stats[stats_dataset_name]["columns"] = int(data.shape[1]) + + current_dataset = self.datasets.get(dataset_name, None) + + if current_dataset: + self._stats[stats_dataset_name]["file_size"] = self.get_file_size( + current_dataset + ) + + except ImportError as exc: # pragma: no cover + logger.warning( + "Unable to import dependencies to extract dataset statistics for %s : %s", + dataset_name, + exc, + ) + except Exception as exc: # pragma: no cover + logger.warning( + "[hook: after_dataset_saved] Unable to create statistics for the dataset %s : %s", + dataset_name, + exc, + ) + + def get_file_size(self, dataset: Any) -> Union[int, None]: + """Helper method to return the file size of a dataset + + Args: + dataset: A dataset instance for which we need the file size + + Returns: file size for the dataset if file_path is valid, if not 
returns None + """ + + if not (hasattr(dataset, "_filepath") and dataset._filepath): + return None + + try: + file_path = get_filepath_str(dataset._filepath, dataset._protocol) + return dataset._fs.size(file_path) + + except Exception as exc: + logger.warning( + "Unable to get file size for the dataset %s: %s", dataset, exc + ) + return None + + def format_stats(self, stats: dict) -> dict: + """Sort the stats extracted from the datasets using the sort order + + Args: + stats: A dictionary of statistics for a dataset + + Returns: A sorted dictionary based on the sort_order + """ + # Custom sort order + sort_order = ["rows", "columns", "file_size"] + return {stat: stats.get(stat) for stat in sort_order if stat in stats} + + def get_stats_dataset_name(self, dataset_name: str) -> str: + """Get the dataset name for assigning stat values in the dictionary. + If the dataset name contains transcoded information, strip the transcoding. + + Args: + dataset_name: name of the dataset + + Returns: Dataset name without any transcoding information + """ + + stats_dataset_name = dataset_name + + # Strip transcoding + is_transcoded_dataset = TRANSCODING_SEPARATOR in dataset_name + if is_transcoded_dataset: + stats_dataset_name = _strip_transcoding(dataset_name) + + return stats_dataset_name + + +dataset_stats_hook = DatasetStatsHook() diff --git a/package/kedro_viz/models/flowchart.py b/package/kedro_viz/models/flowchart.py index 8b668ead5c..bcf1d9b232 100644 --- a/package/kedro_viz/models/flowchart.py +++ b/package/kedro_viz/models/flowchart.py @@ -13,7 +13,7 @@ from kedro.pipeline.node import Node as KedroNode from kedro.pipeline.pipeline import TRANSCODING_SEPARATOR, _strip_transcoding -from .utils import get_dataset_type +from kedro_viz.models.utils import get_dataset_type try: # kedro 0.18.11 onwards @@ -541,6 +541,7 @@ class DataNodeMetadata(GraphNodeMetadata): # the underlying data node to which this metadata belongs data_node: InitVar[DataNode] + dataset_stats: 
InitVar[Dict] # the optional plot data if the underlying dataset has a plot. # currently only applicable for PlotlyDataSet @@ -557,12 +558,15 @@ class DataNodeMetadata(GraphNodeMetadata): preview: Optional[Dict] = field(init=False, default=None) + stats: Optional[Dict] = field(init=False, default=None) + # TODO: improve this scheme. - def __post_init__(self, data_node: DataNode): + def __post_init__(self, data_node: DataNode, dataset_stats: Dict): self.type = data_node.dataset_type dataset = cast(AbstractDataset, data_node.kedro_obj) dataset_description = dataset._describe() self.filepath = _parse_filepath(dataset_description) + self.stats = dataset_stats # Run command is only available if a node is an output, i.e. not a free input if not data_node.is_free_input: @@ -615,10 +619,15 @@ class TranscodedDataNodeMetadata(GraphNodeMetadata): transcoded_types: List[str] = field(init=False) + stats: Optional[Dict] = field(init=False, default=None) + # the underlying data node to which this metadata belongs transcoded_data_node: InitVar[TranscodedDataNode] + dataset_stats: InitVar[Dict] - def __post_init__(self, transcoded_data_node: TranscodedDataNode): + def __post_init__( + self, transcoded_data_node: TranscodedDataNode, dataset_stats: Dict + ): original_version = transcoded_data_node.original_version self.original_type = get_dataset_type(original_version) @@ -629,6 +638,7 @@ def __post_init__(self, transcoded_data_node: TranscodedDataNode): dataset_description = original_version._describe() self.filepath = _parse_filepath(dataset_description) + self.stats = dataset_stats if not transcoded_data_node.is_free_input: self.run_command = ( diff --git a/package/kedro_viz/models/utils.py b/package/kedro_viz/models/utils.py index cf5722c499..aca687a393 100644 --- a/package/kedro_viz/models/utils.py +++ b/package/kedro_viz/models/utils.py @@ -1,11 +1,15 @@ """`kedro_viz.models.utils` contains utility functions used in the `kedro_viz.models` package""" +import logging from typing 
import TYPE_CHECKING +logger = logging.getLogger(__name__) + + if TYPE_CHECKING: - try: + try: # pragma: no cover # kedro 0.18.12 onwards from kedro.io.core import AbstractDataset - except ImportError: + except ImportError: # pragma: no cover # older versions from kedro.io.core import AbstractDataSet as AbstractDataset diff --git a/package/kedro_viz/server.py b/package/kedro_viz/server.py index c5cc96df88..09c343074c 100644 --- a/package/kedro_viz/server.py +++ b/package/kedro_viz/server.py @@ -1,4 +1,5 @@ -"""`kedro_viz.server` provides utilities to launch a webserver for Kedro pipeline visualisation.""" +"""`kedro_viz.server` provides utilities to launch a webserver +for Kedro pipeline visualisation.""" import webbrowser from pathlib import Path from typing import Any, Dict, Optional @@ -31,6 +32,7 @@ def populate_data( catalog: DataCatalog, pipelines: Dict[str, Pipeline], session_store: BaseSessionStore, + stats_dict: Dict, ): # pylint: disable=redefined-outer-name """Populate data repositories. Should be called once on application start if creating an api app from project. 
@@ -43,6 +45,7 @@ def populate_data( data_access_manager.add_catalog(catalog) data_access_manager.add_pipelines(pipelines) + data_access_manager.add_dataset_stats(stats_dict) def run_server( @@ -80,7 +83,7 @@ def run_server( """ if load_file is None: path = Path(project_path) if project_path else Path.cwd() - catalog, pipelines, session_store = kedro_data_loader.load_data( + catalog, pipelines, session_store, stats_dict = kedro_data_loader.load_data( path, env, extra_params ) pipelines = ( @@ -88,7 +91,9 @@ def run_server( if pipeline_name is None else {pipeline_name: pipelines[pipeline_name]} ) - populate_data(data_access_manager, catalog, pipelines, session_store) + populate_data( + data_access_manager, catalog, pipelines, session_store, stats_dict + ) if save_file: default_response = get_default_response() jsonable_default_response = jsonable_encoder(default_response) diff --git a/package/setup.py b/package/setup.py index 9d5bcee628..19245b748c 100644 --- a/package/setup.py +++ b/package/setup.py @@ -46,5 +46,8 @@ entry_points={ "kedro.global_commands": ["kedro-viz = kedro_viz.launchers.cli:commands"], "kedro.line_magic": ["line_magic = kedro_viz.launchers.jupyter:run_viz"], + "kedro.hooks": [ + "plugin_name = kedro_viz.integrations.kedro.hooks:dataset_stats_hook" + ], }, ) diff --git a/package/tests/conftest.py b/package/tests/conftest.py index a4315918db..8da6a23565 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -3,6 +3,7 @@ from typing import Dict from unittest import mock +import pandas as pd import pytest from fastapi.testclient import TestClient from kedro.framework.session.store import BaseSessionStore @@ -13,6 +14,7 @@ from kedro_viz.api import apps from kedro_viz.data_access import DataAccessManager +from kedro_viz.integrations.kedro.hooks import DatasetStatsHook from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore from kedro_viz.server import populate_data @@ -40,6 +42,16 @@ def sqlite_session_store(tmp_path): 
yield SQLiteStore(tmp_path, "dummy_session_id") +@pytest.fixture +def example_stats_dict(): + yield { + "companies": {"rows": 77096, "columns": 5}, + "reviews": {"rows": 77096, "columns": 10}, + "shuttles": {"rows": 77096, "columns": 13}, + "model_inputs": {"rows": 29768, "columns": 12}, + } + + @pytest.fixture def example_pipelines(): def process_data(raw_data, train_test_split): @@ -157,11 +169,16 @@ def example_api( example_pipelines: Dict[str, Pipeline], example_catalog: DataCatalog, session_store: BaseSessionStore, + example_stats_dict: Dict, mocker, ): api = apps.create_api_app_from_project(mock.MagicMock()) populate_data( - data_access_manager, example_catalog, example_pipelines, session_store + data_access_manager, + example_catalog, + example_pipelines, + session_store, + example_stats_dict, ) mocker.patch( "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager @@ -183,7 +200,7 @@ def example_api_no_default_pipeline( del example_pipelines["__default__"] api = apps.create_api_app_from_project(mock.MagicMock()) populate_data( - data_access_manager, example_catalog, example_pipelines, session_store + data_access_manager, example_catalog, example_pipelines, session_store, {} ) mocker.patch( "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager @@ -208,6 +225,7 @@ def example_transcoded_api( example_transcoded_catalog, example_transcoded_pipelines, session_store, + {}, ) mocker.patch( "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager @@ -255,3 +273,30 @@ def json(self): return self.data return MockHTTPResponse + + +@pytest.fixture +def example_data_frame(): + data = { + "id": ["35029", "30292"], + "company_rating": ["100%", "67%"], + "company_location": ["Niue", "Anguilla"], + "total_fleet_count": ["4.0", "6.0"], + "iata_approved": ["f", "f"], + } + yield pd.DataFrame(data) + + +@pytest.fixture +def example_dataset_stats_hook_obj(): + # Create an instance of DatasetStatsHook + yield 
DatasetStatsHook() + + +@pytest.fixture +def example_csv_dataset(tmp_path, example_data_frame): + new_csv_dataset = pandas.CSVDataSet( + filepath=Path(tmp_path / "model_inputs.csv").as_posix(), + ) + new_csv_dataset.save(example_data_frame) + yield new_csv_dataset diff --git a/package/tests/test_api/test_rest/test_responses.py b/package/tests/test_api/test_rest/test_responses.py index 9d9001846d..16b7a1446b 100644 --- a/package/tests/test_api/test_rest/test_responses.py +++ b/package/tests/test_api/test_rest/test_responses.py @@ -572,6 +572,7 @@ def test_transcoded_data_node_metadata(self, example_transcoded_api): "pandas.parquet_dataset.ParquetDataSet", ], "run_command": "kedro run --to-outputs=model_inputs@pandas2", + "stats": {}, } @@ -605,6 +606,7 @@ def test_data_node_metadata(self, client): "filepath": "model_inputs.csv", "type": "pandas.csv_dataset.CSVDataSet", "run_command": "kedro run --to-outputs=model_inputs", + "stats": {"columns": 12, "rows": 29768}, } def test_data_node_metadata_for_free_input(self, client): @@ -612,6 +614,7 @@ def test_data_node_metadata_for_free_input(self, client): assert response.json() == { "filepath": "raw_data.csv", "type": "pandas.csv_dataset.CSVDataSet", + "stats": {}, } def test_parameters_node_metadata(self, client): diff --git a/package/tests/test_integrations/test_hooks.py b/package/tests/test_integrations/test_hooks.py new file mode 100644 index 0000000000..0894f914e6 --- /dev/null +++ b/package/tests/test_integrations/test_hooks.py @@ -0,0 +1,131 @@ +from collections import defaultdict +from unittest.mock import mock_open, patch + +import pytest +from kedro.io.core import get_filepath_str + +try: + # kedro 0.18.11 onwards + from kedro.io import MemoryDataset +except ImportError: + # older versions + from kedro.io import MemoryDataSet as MemoryDataset + + +def test_dataset_stats_hook_create(example_dataset_stats_hook_obj): + # Assert for an instance of defaultdict + assert hasattr(example_dataset_stats_hook_obj, 
"_stats") + assert isinstance(example_dataset_stats_hook_obj._stats, defaultdict) + + +def test_after_catalog_created(example_dataset_stats_hook_obj, example_catalog): + example_dataset_stats_hook_obj.after_catalog_created(example_catalog) + + # Assert for catalog creation + assert hasattr(example_dataset_stats_hook_obj, "datasets") + assert example_dataset_stats_hook_obj.datasets == example_catalog._data_sets + + +@pytest.mark.parametrize( + "dataset_name", ["companies", "companies@pandas1", "model_inputs"] +) +def test_after_dataset_loaded( + dataset_name, example_dataset_stats_hook_obj, example_catalog, example_data_frame +): + example_dataset_stats_hook_obj.after_catalog_created(example_catalog) + example_dataset_stats_hook_obj.after_dataset_loaded( + dataset_name, example_data_frame + ) + + stats_dataset_name = example_dataset_stats_hook_obj.get_stats_dataset_name( + dataset_name + ) + + assert stats_dataset_name in example_dataset_stats_hook_obj._stats + assert example_dataset_stats_hook_obj._stats[stats_dataset_name]["rows"] == int( + example_data_frame.shape[0] + ) + assert example_dataset_stats_hook_obj._stats[stats_dataset_name]["columns"] == int( + example_data_frame.shape[1] + ) + + +@pytest.mark.parametrize("dataset_name", ["model_inputs"]) +def test_after_dataset_saved( + dataset_name, + mocker, + example_dataset_stats_hook_obj, + example_catalog, + example_data_frame, +): + example_dataset_stats_hook_obj.after_catalog_created(example_catalog) + + # Create a mock object for the 'get_file_size' function + mock_get_file_size = mocker.Mock() + + # Replace the original 'get_file_size' function with the mock + mocker.patch( + "kedro_viz.integrations.kedro.hooks.DatasetStatsHook.get_file_size", + new=mock_get_file_size, + ) + + # Set the return value of the mock + mock_get_file_size.return_value = 10 + + example_dataset_stats_hook_obj.after_dataset_saved(dataset_name, example_data_frame) + + stats_dataset_name = 
example_dataset_stats_hook_obj.get_stats_dataset_name( + dataset_name + ) + + assert stats_dataset_name in example_dataset_stats_hook_obj._stats + assert example_dataset_stats_hook_obj._stats[stats_dataset_name]["rows"] == int( + example_data_frame.shape[0] + ) + assert example_dataset_stats_hook_obj._stats[stats_dataset_name]["columns"] == int( + example_data_frame.shape[1] + ) + assert example_dataset_stats_hook_obj._stats[stats_dataset_name]["file_size"] == 10 + + +@pytest.mark.parametrize("dataset_name", ["companies", "companies@pandas1"]) +def test_after_pipeline_run( + dataset_name, example_dataset_stats_hook_obj, example_data_frame +): + stats_dataset_name = example_dataset_stats_hook_obj.get_stats_dataset_name( + dataset_name + ) + stats_json = { + stats_dataset_name: { + "rows": int(example_data_frame.shape[0]), + "columns": int(example_data_frame.shape[1]), + } + } + # Create a mock_open context manager + with patch("builtins.open", mock_open()) as mock_file, patch( + "json.dump" + ) as mock_json_dump: + example_dataset_stats_hook_obj.after_dataset_loaded( + dataset_name, example_data_frame + ) + example_dataset_stats_hook_obj.after_pipeline_run() + + # Assert that the file was opened with the correct filename + mock_file.assert_called_once_with("stats.json", "w", encoding="utf8") + + # Assert that json.dump was called with the expected arguments + mock_json_dump.assert_called_once_with(stats_json, mock_file.return_value) + + +@pytest.mark.parametrize( + "dataset", + [MemoryDataset()], +) +def test_get_file_size(dataset, example_dataset_stats_hook_obj, example_csv_dataset): + assert example_dataset_stats_hook_obj.get_file_size(dataset) is None + file_path = get_filepath_str( + example_csv_dataset._filepath, example_csv_dataset._protocol + ) + assert example_dataset_stats_hook_obj.get_file_size( + example_csv_dataset + ) == example_csv_dataset._fs.size(file_path) diff --git a/package/tests/test_models/test_flowchart.py 
b/package/tests/test_models/test_flowchart.py index 54706afa32..e66a411357 100644 --- a/package/tests/test_models/test_flowchart.py +++ b/package/tests/test_models/test_flowchart.py @@ -355,10 +355,14 @@ def test_data_node_metadata(self): tags=set(), dataset=dataset, ) - data_node_metadata = DataNodeMetadata(data_node=data_node) + data_node_metadata = DataNodeMetadata( + data_node=data_node, dataset_stats={"rows": 10, "columns": 2} + ) assert data_node_metadata.type == "pandas.csv_dataset.CSVDataSet" assert data_node_metadata.filepath == "/tmp/dataset.csv" assert data_node_metadata.run_command == "kedro run --to-outputs=dataset" + assert data_node_metadata.stats["rows"] == 10 + assert data_node_metadata.stats["columns"] == 2 def test_preview_args_not_exist(self): metadata = {"kedro-viz": {"something": 3}} @@ -401,7 +405,9 @@ def test_preview_data_node_metadata(self): preview_data_node.is_tracking_node.return_value = False preview_data_node.is_preview_node.return_value = True preview_data_node.kedro_obj._preview.return_value = mock_preview_data - preview_node_metadata = DataNodeMetadata(data_node=preview_data_node) + preview_node_metadata = DataNodeMetadata( + data_node=preview_data_node, dataset_stats={} + ) assert preview_node_metadata.preview == mock_preview_data def test_preview_data_node_metadata_not_exist(self): @@ -412,7 +418,9 @@ def test_preview_data_node_metadata_not_exist(self): preview_data_node.is_tracking_node.return_value = False preview_data_node.is_preview_node.return_value = True preview_data_node.kedro_obj._preview.return_value = False - preview_node_metadata = DataNodeMetadata(data_node=preview_data_node) + preview_node_metadata = DataNodeMetadata( + data_node=preview_data_node, dataset_stats={} + ) assert preview_node_metadata.plot is None def test_transcoded_data_node_metadata(self): @@ -427,7 +435,8 @@ def test_transcoded_data_node_metadata(self): transcoded_data_node.original_version = ParquetDataSet(filepath="foo.parquet") 
transcoded_data_node.transcoded_versions = [CSVDataSet(filepath="foo.csv")] transcoded_data_node_metadata = TranscodedDataNodeMetadata( - transcoded_data_node=transcoded_data_node + transcoded_data_node=transcoded_data_node, + dataset_stats={"rows": 10, "columns": 2}, ) assert ( transcoded_data_node_metadata.original_type @@ -437,6 +446,8 @@ def test_transcoded_data_node_metadata(self): assert transcoded_data_node_metadata.transcoded_types == [ "pandas.csv_dataset.CSVDataSet" ] + assert transcoded_data_node_metadata.stats["rows"] == 10 + assert transcoded_data_node_metadata.stats["columns"] == 2 def test_partitioned_data_node_metadata(self): dataset = PartitionedDataset(path="partitioned/", dataset="pandas.CSVDataSet") @@ -446,7 +457,7 @@ def test_partitioned_data_node_metadata(self): tags=set(), dataset=dataset, ) - data_node_metadata = DataNodeMetadata(data_node=data_node) + data_node_metadata = DataNodeMetadata(data_node=data_node, dataset_stats={}) assert data_node_metadata.filepath == "partitioned/" # TODO: these test should ideally use a "real" catalog entry to create actual rather @@ -468,7 +479,9 @@ def test_plotly_data_node_metadata(self): plotly_data_node.is_tracking_node.return_value = False plotly_data_node.is_preview_node.return_value = False plotly_data_node.kedro_obj.load.return_value = mock_plot_data - plotly_node_metadata = DataNodeMetadata(data_node=plotly_data_node) + plotly_node_metadata = DataNodeMetadata( + data_node=plotly_data_node, dataset_stats={} + ) assert plotly_node_metadata.plot == mock_plot_data def test_plotly_data_node_dataset_not_exist(self): @@ -478,7 +491,9 @@ def test_plotly_data_node_dataset_not_exist(self): plotly_data_node.is_tracking_node.return_value = False plotly_data_node.kedro_obj.exists.return_value = False plotly_data_node.is_preview_node.return_value = False - plotly_node_metadata = DataNodeMetadata(data_node=plotly_data_node) + plotly_node_metadata = DataNodeMetadata( + data_node=plotly_data_node, dataset_stats={} 
+ ) assert plotly_node_metadata.plot is None def test_plotly_json_dataset_node_metadata(self): @@ -497,7 +512,9 @@ def test_plotly_json_dataset_node_metadata(self): plotly_json_dataset_node.is_tracking_node.return_value = False plotly_json_dataset_node.is_preview_node.return_value = False plotly_json_dataset_node.kedro_obj.load.return_value = mock_plot_data - plotly_node_metadata = DataNodeMetadata(data_node=plotly_json_dataset_node) + plotly_node_metadata = DataNodeMetadata( + data_node=plotly_json_dataset_node, dataset_stats={} + ) assert plotly_node_metadata.plot == mock_plot_data # @patch("base64.b64encode") @@ -512,7 +529,9 @@ def test_image_data_node_metadata(self): image_dataset_node.is_tracking_node.return_value = False image_dataset_node.is_preview_node.return_value = False image_dataset_node.kedro_obj.load.return_value = mock_image_data - image_node_metadata = DataNodeMetadata(data_node=image_dataset_node) + image_node_metadata = DataNodeMetadata( + data_node=image_dataset_node, dataset_stats={} + ) assert image_node_metadata.image == mock_image_data def test_image_data_node_dataset_not_exist(self): @@ -521,7 +540,9 @@ def test_image_data_node_dataset_not_exist(self): image_dataset_node.is_plot_node.return_value = False image_dataset_node.kedro_obj.exists.return_value = False image_dataset_node.is_preview_node.return_value = False - image_node_metadata = DataNodeMetadata(data_node=image_dataset_node) + image_node_metadata = DataNodeMetadata( + data_node=image_dataset_node, dataset_stats={} + ) assert image_node_metadata.image is None def test_json_data_node_metadata(self): @@ -538,7 +559,9 @@ def test_json_data_node_metadata(self): json_data_node.is_metric_node.return_value = False json_data_node.is_preview_node.return_value = False json_data_node.kedro_obj.load.return_value = mock_json_data - json_node_metadata = DataNodeMetadata(data_node=json_data_node) + json_node_metadata = DataNodeMetadata( + data_node=json_data_node, dataset_stats={} + ) assert 
json_node_metadata.tracking_data == mock_json_data assert json_node_metadata.plot is None @@ -549,7 +572,9 @@ def test_metrics_data_node_metadata_dataset_not_exist(self): metrics_data_node.is_metric_node.return_value = True metrics_data_node.is_preview_node.return_value = False metrics_data_node.kedro_obj.exists.return_value = False - metrics_node_metadata = DataNodeMetadata(data_node=metrics_data_node) + metrics_node_metadata = DataNodeMetadata( + data_node=metrics_data_node, dataset_stats={} + ) assert metrics_node_metadata.plot is None def test_data_node_metadata_latest_tracking_data_not_exist(self): @@ -559,7 +584,9 @@ def test_data_node_metadata_latest_tracking_data_not_exist(self): plotly_data_node.is_tracking_node.return_value = False plotly_data_node.kedro_obj.exists.return_value = False plotly_data_node.kedro_obj.exists.return_value = False - plotly_node_metadata = DataNodeMetadata(data_node=plotly_data_node) + plotly_node_metadata = DataNodeMetadata( + data_node=plotly_data_node, dataset_stats={} + ) assert plotly_node_metadata.plot is None def test_parameters_metadata_all_parameters(self): diff --git a/package/tests/test_models/test_utils.py b/package/tests/test_models/test_utils.py new file mode 100644 index 0000000000..70de7a8cd7 --- /dev/null +++ b/package/tests/test_models/test_utils.py @@ -0,0 +1,18 @@ +import pytest + +from kedro_viz.models.utils import get_dataset_type + +try: + # kedro 0.18.11 onwards + from kedro.io import MemoryDataset +except ImportError: + # older versions + from kedro.io import MemoryDataSet as MemoryDataset + + +@pytest.mark.parametrize( + "dataset,expected_type", + [(None, ""), (MemoryDataset(), "io.memory_dataset.MemoryDataset")], +) +def test_get_dataset_type(dataset, expected_type): + assert get_dataset_type(dataset) == expected_type diff --git a/package/tests/test_server.py b/package/tests/test_server.py index 6f99191dae..c9fcdead88 100644 --- a/package/tests/test_server.py +++ b/package/tests/test_server.py @@ -32,20 
+32,23 @@ def patched_create_api_app_from_file(mocker): @pytest.fixture(autouse=True) -def patched_load_data(mocker, example_catalog, example_pipelines, session_store): +def patched_load_data( + mocker, example_catalog, example_pipelines, session_store, example_stats_dict +): yield mocker.patch( "kedro_viz.server.kedro_data_loader.load_data", return_value=( example_catalog, example_pipelines, session_store, + example_stats_dict, ), ) @pytest.fixture def patched_load_data_with_sqlite_session_store( - mocker, example_catalog, example_pipelines, sqlite_session_store + mocker, example_catalog, example_pipelines, sqlite_session_store, example_stats_dict ): yield mocker.patch( "kedro_viz.server.kedro_data_loader.load_data", @@ -53,6 +56,7 @@ def patched_load_data_with_sqlite_session_store( example_catalog, example_pipelines, sqlite_session_store, + example_stats_dict, ), ) diff --git a/src/components/metadata/metadata-stats.js b/src/components/metadata/metadata-stats.js new file mode 100644 index 0000000000..ffff03be6c --- /dev/null +++ b/src/components/metadata/metadata-stats.js @@ -0,0 +1,55 @@ +import React, { useState, useRef, useLayoutEffect } from 'react'; +import { formatFileSize, formatNumberWithCommas } from '../../utils'; +import { datasetStatLabels } from '../../config'; +import './styles/metadata-stats.css'; + +const MetaDataStats = ({ stats }) => { + const [hasOverflow, setHasOverflow] = useState(false); + const statsContainerRef = useRef(null); + + useLayoutEffect(() => { + const statsContainer = statsContainerRef.current; + + if (!statsContainer) { + return; + } + + const containerWidth = statsContainer.clientWidth; + const totalItemsWidth = Array.from(statsContainer.children).reduce( + (total, item) => total + item.offsetWidth, + 0 + ); + + setHasOverflow(totalItemsWidth > containerWidth); + }, []); + + return ( + + ); +}; + +export default MetaDataStats; diff --git a/src/components/metadata/metadata.js b/src/components/metadata/metadata.js index 
01b618c5ff..a27898843a 100644 --- a/src/components/metadata/metadata.js +++ b/src/components/metadata/metadata.js @@ -24,6 +24,7 @@ import { } from '../../utils/hooks/use-generate-pathname'; import './styles/metadata.css'; +import MetaDataStats from './metadata-stats'; /** * Shows node meta data @@ -58,6 +59,7 @@ const MetaData = ({ const hasImage = Boolean(metadata?.image); const hasTrackingData = Boolean(metadata?.trackingData); const hasPreviewData = Boolean(metadata?.preview); + const hasStatsData = Boolean(metadata?.stats); const isMetricsTrackingDataset = nodeTypeIcon === 'metricsTracking'; const hasCode = Boolean(metadata?.code); const isTranscoded = Boolean(metadata?.originalType); @@ -232,6 +234,17 @@ const MetaData = ({ isCommand={metadata?.runCommand} /> + {hasStatsData && ( + <> + + Dataset statistics: + + + + )} {hasPlot && ( <> diff --git a/src/components/metadata/metadata.test.js b/src/components/metadata/metadata.test.js index 47fd7b61dd..7fcfc9c57e 100644 --- a/src/components/metadata/metadata.test.js +++ b/src/components/metadata/metadata.test.js @@ -9,9 +9,11 @@ import nodePlot from '../../utils/data/node_plot.mock.json'; import nodeParameters from '../../utils/data/node_parameters.mock.json'; import nodeTask from '../../utils/data/node_task.mock.json'; import nodeData from '../../utils/data/node_data.mock.json'; +import nodeDataStats from '../../utils/data/node_data_stats.mock.json'; import nodeTranscodedData from '../../utils/data/node_transcoded_data.mock.json'; import nodeMetricsData from '../../utils/data/node_metrics_data.mock.json'; import nodeJSONData from '../../utils/data/node_json_data.mock.json'; +import { formatFileSize } from '../../utils'; const modelInputDataSetNodeId = '23c94afb'; const splitDataTaskNodeId = '65d0d789'; @@ -356,6 +358,37 @@ describe('MetaData', () => { ); }); }); + + describe('when there is stats returned by the backend', () => { + it('shows the node statistics', () => { + const wrapper = mount({ + nodeId: 
modelInputDataSetNodeId, + mockMetadata: nodeDataStats, + }); + + expect(wrapper.find('[data-label="Dataset statistics:"]').length).toBe( + 1 + ); + expect(wrapper.find('[data-test="stats-label-rows"]').length).toBe(1); + expect(wrapper.find('[data-test="stats-label-columns"]').length).toBe( + 1 + ); + expect(wrapper.find('[data-test="stats-label-file_size"]').length).toBe( + 1 + ); + + expect( + parseInt(wrapper.find('[data-test="stats-value-rows"]').text()) + ).toEqual(nodeDataStats.stats.rows); + expect( + parseInt(wrapper.find('[data-test="stats-value-columns"]').text()) + ).toEqual(nodeDataStats.stats.columns); + expect( + wrapper.find('[data-test="stats-value-file_size"]').text() + ).toEqual(formatFileSize(nodeDataStats.stats.file_size)); + }); + }); + describe('Transcoded dataset nodes', () => { it('shows the node original type', () => { const wrapper = mount({ diff --git a/src/components/metadata/styles/metadata-stats.scss b/src/components/metadata/styles/metadata-stats.scss new file mode 100644 index 0000000000..f480dae522 --- /dev/null +++ b/src/components/metadata/styles/metadata-stats.scss @@ -0,0 +1,23 @@ +.pipeline-metadata-label__stats { + margin: 0 24px 0 8px; +} + +.pipeline-metadata-value__stats { + font-size: 18px; +} + +.stats-container__overflow { + align-items: baseline; + display: grid; + gap: 8px; + grid-template-columns: auto 1fr; + + .pipeline-metadata-value__stats { + margin-left: auto; + margin-right: 0; + } + + .pipeline-metadata-label__stats { + margin-right: auto; + } +} diff --git a/src/config.js b/src/config.js index b623ccb534..49103048b5 100644 --- a/src/config.js +++ b/src/config.js @@ -141,3 +141,5 @@ export const errorMessages = { experimentTracking: `Please check the spelling of "run_ids" or "view" or "comparison" in the URL. It may be a typo 😇`, runIds: `Please check the value of "run_ids" in the URL. 
Perhaps you've deleted the entity 🙈 or it may be a typo 😇`, }; + +export const datasetStatLabels = ['rows', 'columns', 'file_size']; diff --git a/src/reducers/nodes.js b/src/reducers/nodes.js index 51aac406cf..08116be8f2 100644 --- a/src/reducers/nodes.js +++ b/src/reducers/nodes.js @@ -90,6 +90,9 @@ function nodeReducer(nodeState = {}, action) { preview: Object.assign({}, nodeState.preview, { [id]: data.preview, }), + stats: Object.assign({}, nodeState.stats, { + [id]: data.stats, + }), }); } diff --git a/src/selectors/metadata.js b/src/selectors/metadata.js index ac697e8301..f2794e42fa 100644 --- a/src/selectors/metadata.js +++ b/src/selectors/metadata.js @@ -42,6 +42,7 @@ export const getClickedNodeMetaData = createSelector( (state) => state.node.transcodedTypes, (state) => state.node.runCommand, (state) => state.node.preview, + (state) => state.node.stats, (state) => state.isPrettyName, ], ( @@ -66,6 +67,7 @@ export const getClickedNodeMetaData = createSelector( nodeTranscodedTypes, nodeRunCommand, preview, + stats, isPrettyName ) => { if (!nodeId || Object.keys(nodeType).length === 0) { @@ -114,6 +116,7 @@ export const getClickedNodeMetaData = createSelector( : nodeOutputs[nodeId] && nodeOutputs[nodeId].map((nodeOutput) => stripNamespace(nodeOutput)), preview: preview && preview[nodeId], + stats: stats && stats[nodeId], }; return metadata; diff --git a/src/utils/data/node_data_stats.mock.json b/src/utils/data/node_data_stats.mock.json new file mode 100644 index 0000000000..7e6c07bca5 --- /dev/null +++ b/src/utils/data/node_data_stats.mock.json @@ -0,0 +1,10 @@ +{ + "filepath": "/tmp/project/data/03_primary/model_input_table.csv", + "type": "pandas.csv_dataset.CSVDataSet", + "run_command": "kedro run --to-outputs=model_input_table", + "stats": { + "rows": 10, + "columns": 2, + "file_size": 1100 + } +} diff --git a/src/utils/index.js b/src/utils/index.js index e85ca32757..069e235169 100644 --- a/src/utils/index.js +++ b/src/utils/index.js @@ -122,3 +122,37 @@ 
export const prettifyModularPipelineNames = (modularPipelines) => { } return modularPipelines; }; + +/** + * Formats file size for the dataset metadata stats + * @param {Number} fileSizeInBytes The file size in bytes + * @returns {String} The formatted file size as e.g. "1.1KB" + */ +export const formatFileSize = (fileSizeInBytes) => { + // This is to convert bytes to KB or MB. + const conversionUnit = 1000; + + if (!fileSizeInBytes) { + // dataset not configured + return 'N/A'; + } else if (fileSizeInBytes < conversionUnit) { + // Less than 1 KB + return `${fileSizeInBytes} bytes`; + } else if (fileSizeInBytes < conversionUnit * conversionUnit) { + // Less than 1 MB + const sizeInKB = fileSizeInBytes / conversionUnit; + return `${sizeInKB.toFixed(1)}KB`; + } else { + const sizeInMB = fileSizeInBytes / (conversionUnit * conversionUnit); + return `${sizeInMB.toFixed(1)}MB`; + } +}; + +/** + * Formats a number to a comma separated string + * @param {Number} number The number to be formatted + * @returns {String} The formatted number e.g. 2500 -> 2,500 + */ +export const formatNumberWithCommas = (number) => { + return number.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ','); +};