diff --git a/external/anomaly/tests/conftest.py b/external/anomaly/tests/conftest.py index 0e4d54919f2..64a8d7a2321 100644 --- a/external/anomaly/tests/conftest.py +++ b/external/anomaly/tests/conftest.py @@ -41,6 +41,16 @@ def ote_templates_root_dir_fx(): logger.debug(f'overloaded ote_templates_root_dir_fx: return {root}') return root +@pytest.fixture(scope='session') +def ote_reference_root_dir_fx(): + import os.path as osp + import logging + logger = logging.getLogger(__name__) + root = osp.dirname(osp.dirname(osp.realpath(__file__))) + root = f'{root}/tests/reference/' + logger.debug(f'overloaded ote_reference_root_dir_fx: return {root}') + return root + # pytest magic def pytest_generate_tests(metafunc): ote_pytest_generate_tests_insertion(metafunc) diff --git a/external/anomaly/tests/test_ote_training.py b/external/anomaly/tests/test_ote_training.py index 1859523ebdc..19d2d7a59fa 100644 --- a/external/anomaly/tests/test_ote_training.py +++ b/external/anomaly/tests/test_ote_training.py @@ -58,7 +58,8 @@ OTETestNNCFAction, OTETestNNCFEvaluationAction, OTETestNNCFExportAction, - OTETestNNCFExportEvaluationAction) + OTETestNNCFExportEvaluationAction, + OTETestNNCFGraphAction) logger = get_logger(__name__) @@ -119,6 +120,7 @@ def get_anomaly_test_action_classes() -> List[Type[BaseOTETestAction]]: OTETestNNCFEvaluationAction, OTETestNNCFExportAction, OTETestNNCFExportEvaluationAction, + OTETestNNCFGraphAction, ] @@ -294,7 +296,8 @@ def get_list_of_tests(cls, usecase: Optional[str] = None): @pytest.fixture def params_factories_for_test_actions_fx(self, current_test_parameters_fx, - dataset_definitions_fx, template_paths_fx) -> Dict[str,Callable[[], Dict]]: + dataset_definitions_fx,ote_current_reference_dir_fx, + template_paths_fx) -> Dict[str,Callable[[], Dict]]: logger.debug('params_factories_for_test_actions_fx: begin') test_parameters = deepcopy(current_test_parameters_fx) @@ -323,8 +326,36 @@ def _training_params_factory() -> Dict: 'patience': patience, 'batch_size': batch_size, } + + def _nncf_graph_params_factory() -> Dict: + if dataset_definitions is None: + pytest.skip('The parameter "--dataset-definitions" is not set') + + model_name = test_parameters['model_name'] + dataset_name = test_parameters['dataset_name'] + + dataset_params = _get_dataset_params_from_dataset_definitions(dataset_definitions, dataset_name) + + if model_name not in template_paths: + raise ValueError(f'Model {model_name} is absent in template_paths, ' + f'template_paths.keys={list(template_paths.keys())}') + template_path = make_path_be_abs(template_paths[model_name], template_paths[ROOT_PATH_KEY]) + + logger.debug('training params factory: Before creating dataset and labels_schema') + dataset, labels_schema = _create_anomaly_classification_dataset_and_labels_schema(dataset_params, dataset_name) + logger.debug('training params factory: After creating dataset and labels_schema') + + return { + 'dataset': dataset, + 'labels_schema': labels_schema, + 'template_path': template_path, + 'reference_dir': ote_current_reference_dir_fx, + 'fn_get_compressed_model': None #NNCF not yet implemented in Anomaly + } + params_factories_for_test_actions = { - 'training': _training_params_factory + 'training': _training_params_factory, + 'nncf_graph': _nncf_graph_params_factory, } logger.debug('params_factories_for_test_actions_fx: end') return params_factories_for_test_actions