Skip to content

Commit

Permalink
Applying comments v2
Browse files Browse the repository at this point in the history
  • Loading branch information
pfinashx committed Mar 14, 2022
1 parent d13d518 commit 80249a4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
10 changes: 10 additions & 0 deletions external/anomaly/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ def ote_templates_root_dir_fx():
logger.debug(f'overloaded ote_templates_root_dir_fx: return {root}')
return root

@pytest.fixture(scope='session')
def ote_reference_root_dir_fx():
import os.path as osp
import logging
logger = logging.getLogger(__name__)
root = osp.dirname(osp.dirname(osp.realpath(__file__)))
root = f'{root}/tests/reference/'
logger.debug(f'overloaded ote_reference_root_dir_fx: return {root}')
return root

# pytest magic
def pytest_generate_tests(metafunc):
ote_pytest_generate_tests_insertion(metafunc)
Expand Down
37 changes: 34 additions & 3 deletions external/anomaly/tests/test_ote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
OTETestNNCFAction,
OTETestNNCFEvaluationAction,
OTETestNNCFExportAction,
OTETestNNCFExportEvaluationAction)
OTETestNNCFExportEvaluationAction,
OTETestNNCFGraphAction)

logger = get_logger(__name__)

Expand Down Expand Up @@ -119,6 +120,7 @@ def get_anomaly_test_action_classes() -> List[Type[BaseOTETestAction]]:
OTETestNNCFEvaluationAction,
OTETestNNCFExportAction,
OTETestNNCFExportEvaluationAction,
OTETestNNCFGraphAction,
]


Expand Down Expand Up @@ -294,7 +296,8 @@ def get_list_of_tests(cls, usecase: Optional[str] = None):

@pytest.fixture
def params_factories_for_test_actions_fx(self, current_test_parameters_fx,
dataset_definitions_fx, template_paths_fx) -> Dict[str,Callable[[], Dict]]:
dataset_definitions_fx,ote_current_reference_dir_fx,
template_paths_fx) -> Dict[str,Callable[[], Dict]]:
logger.debug('params_factories_for_test_actions_fx: begin')

test_parameters = deepcopy(current_test_parameters_fx)
Expand Down Expand Up @@ -323,8 +326,36 @@ def _training_params_factory() -> Dict:
'patience': patience,
'batch_size': batch_size,
}

def _nncf_graph_params_factory() -> Dict:
if dataset_definitions is None:
pytest.skip('The parameter "--dataset-definitions" is not set')

model_name = test_parameters['model_name']
dataset_name = test_parameters['dataset_name']

dataset_params = _get_dataset_params_from_dataset_definitions(dataset_definitions, dataset_name)

if model_name not in template_paths:
raise ValueError(f'Model {model_name} is absent in template_paths, '
f'template_paths.keys={list(template_paths.keys())}')
template_path = make_path_be_abs(template_paths[model_name], template_paths[ROOT_PATH_KEY])

logger.debug('training params factory: Before creating dataset and labels_schema')
dataset, labels_schema = _create_anomaly_classification_dataset_and_labels_schema(dataset_params, dataset_name)
logger.debug('training params factory: After creating dataset and labels_schema')

return {
'dataset': dataset,
'labels_schema': labels_schema,
'template_path': template_path,
'reference_dir': ote_current_reference_dir_fx,
'fn_get_compressed_model': None #NNCF not yet implemented in Anomaly
}

params_factories_for_test_actions = {
'training': _training_params_factory
'training': _training_params_factory,
'nncf_graph': _nncf_graph_params_factory,
}
logger.debug('params_factories_for_test_actions_fx: end')
return params_factories_for_test_actions
Expand Down

0 comments on commit 80249a4

Please sign in to comment.