From 654e5cbc84201005968daeaf045d201719c40db7 Mon Sep 17 00:00:00 2001
From: WANG Tengfei
Date: Fri, 19 May 2023 10:48:16 +0800
Subject: [PATCH 01/13] Update README.md

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 52c1193..c193390 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
-# deep_vital
+# Deep Vital
 PyTorch version of [non-invasive-bp-estimation-using-deep-learning](https://github.com/Fabian-Sc85/non-invasive-bp-estimation-using-deep-learning)
+Pretrained models can be found on [Zenodo](https://zenodo.org/record/7948098)
+

From c3ef27805b34aee3aae2239176d4e5ba0479c939 Mon Sep 17 00:00:00 2001
From: WANG Tengfei
Date: Fri, 19 May 2023 10:49:21 +0800
Subject: [PATCH 02/13] Update resnet50_1xb16_bp.py

---
 configs/resnet50_1xb16_bp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/resnet50_1xb16_bp.py b/configs/resnet50_1xb16_bp.py
index 2afb950..1541c8d 100644
--- a/configs/resnet50_1xb16_bp.py
+++ b/configs/resnet50_1xb16_bp.py
@@ -8,7 +8,7 @@
                  frozen_stages=4,
                  init_cfg=dict(
                      type='Pretrained',
-                     checkpoint='data/resnet_ppg_nonmixed_backbone_v2.pb')),
+                     checkpoint='data/resnet_ppg_nonmixed_backbone.pth')),
             neck=dict(type='AveragePooling'),
             head=dict(type='BPDenseHead', loss=dict(type='MSELoss')))

From b5207529623ee7c62e08c9f7c0c74dd107d055a6 Mon Sep 17 00:00:00 2001
From: WANG Tengfei
Date: Fri, 19 May 2023 23:26:50 +0800
Subject: [PATCH 03/13] Create python-publish.yml

---
 .github/workflows/python-publish.yml | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 .github/workflows/python-publish.yml

diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
new file mode 100644
index 0000000..6ea2726
--- /dev/null
+++ b/.github/workflows/python-publish.yml
@@ -0,0 +1,28 @@
+name: deploy
+
+on: push
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build-n-publish:
+    runs-on: ubuntu-latest
+    if: startsWith(github.event.ref, 'refs/tags')
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.7
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.7
+      - name: Install torch
+        run: pip install torch
+      - name: Install wheel
+        run: pip install wheel
+      - name: Build Deep Vital
+        run: python setup.py sdist bdist_wheel
+      - name: Publish distribution to PyPI
+        run: |
+          pip install twine
+          twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}

From 0ba22535edf4e639437cfcd7adb35829dcb2f038 Mon Sep 17 00:00:00 2001
From: WANG Tengfei
Date: Fri, 19 May 2023 16:09:30 +0000
Subject: [PATCH 04/13] [feature] update format

---
 README.md                                    |  2 ++
 configs/resnet/README.md                     | 23 +++++++++++++++++++
 configs/{ => resnet}/resnet50_1xb16_bp.py    |  2 +-
 deep_vital/models/__init__.py                |  2 +-
 deep_vital/models/blood_pressure/__init__.py |  1 +
 .../models/{ => blood_pressure}/bp_resnet.py |  0
 6 files changed, 28 insertions(+), 2 deletions(-)
 create mode 100644 configs/resnet/README.md
 rename configs/{ => resnet}/resnet50_1xb16_bp.py (98%)
 create mode 100644 deep_vital/models/blood_pressure/__init__.py
 rename deep_vital/models/{ => blood_pressure}/bp_resnet.py (100%)

diff --git a/README.md b/README.md
index c193390..5474e2e 100644
--- a/README.md
+++ b/README.md
@@ -4,3 +4,5 @@ PyTorch version of [non-invasive-bp-estimation-using-deep-learning](https://gith
 Pretrained models can be found on [Zenodo](https://zenodo.org/record/7948098)
+
+## Models
+- [ResNet](configs/resnet/README.md)
\ No
newline at end of file diff --git a/configs/resnet/README.md b/configs/resnet/README.md new file mode 100644 index 0000000..877e2c3 --- /dev/null +++ b/configs/resnet/README.md @@ -0,0 +1,23 @@ +# ResNet Blood Pressure Model + +> [Assessment of Non-Invasive Blood Pressure Prediction from PPG and rPPG Signals Using Deep Learning](https://readpaper.com/paper/3198128029) + +## Abstract + +## Results and Models +| Model | MAE-SBP | MAE-DBP | +| :---: | :-----: | :-----: | +| ResNet501D | - | - | + + +## Citation + +```latex +@inproceedings{schrumpf2021assessment, + title={Assessment of deep learning based blood pressure prediction from PPG and rPPG signals}, + author={Schrumpf, Fabian and Frenzel, Patrick and Aust, Christoph and Osterhoff, Georg and Fuchs, Mirco}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={3820--3830}, + year={2021} +} +``` diff --git a/configs/resnet50_1xb16_bp.py b/configs/resnet/resnet50_1xb16_bp.py similarity index 98% rename from configs/resnet50_1xb16_bp.py rename to configs/resnet/resnet50_1xb16_bp.py index 1541c8d..d554245 100644 --- a/configs/resnet50_1xb16_bp.py +++ b/configs/resnet/resnet50_1xb16_bp.py @@ -1,4 +1,4 @@ -_base_ = ['_base_/default_runtime.py'] +_base_ = ['../_base_/default_runtime.py'] model = dict(type='BPResNet1D', data_preprocessor=dict(type='RppgDataPreprocessor'), diff --git a/deep_vital/models/__init__.py b/deep_vital/models/__init__.py index bc28491..9db2068 100644 --- a/deep_vital/models/__init__.py +++ b/deep_vital/models/__init__.py @@ -1,4 +1,4 @@ -from .bp_resnet import BPResNet1D +from .blood_pressure import * from .backbones import * from .necks import * from .heads import * diff --git a/deep_vital/models/blood_pressure/__init__.py b/deep_vital/models/blood_pressure/__init__.py new file mode 100644 index 0000000..08b3731 --- /dev/null +++ b/deep_vital/models/blood_pressure/__init__.py @@ -0,0 +1 @@ +from .bp_resnet import BPResNet1D \ No newline at end of file diff --git a/deep_vital/models/bp_resnet.py b/deep_vital/models/blood_pressure/bp_resnet.py similarity index 100% rename from deep_vital/models/bp_resnet.py rename to deep_vital/models/blood_pressure/bp_resnet.py From 50bfac2c2b16366aac8a7439c49ffe50dca04384 Mon Sep 17 00:00:00 2001 From: WANG Tengfei Date: Sat, 20 May 2023 16:27:25 +0000 Subject: [PATCH 05/13] [feature] using used_idx_file load data --- configs/resnet/resnet50_1xb16_bp.py | 16 +++++++++++++--- deep_vital/datasets/ppg_signal_dataset.py | 12 ++++++------ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/configs/resnet/resnet50_1xb16_bp.py b/configs/resnet/resnet50_1xb16_bp.py index d554245..a8ca11b 100644 --- a/configs/resnet/resnet50_1xb16_bp.py +++ b/configs/resnet/resnet50_1xb16_bp.py @@ -24,7 +24,7 @@ num_workers=2, dataset=dict(type=dataset_type, ann_file='data/rPPG-BP-UKL_rppg_7s.h5', - used_subjects=[7., 8., 10., 13., 14., 19., 21., 23., 24., 27., 33., 35., 36., 44., 46., 48.], + used_idx_file='data/train.txt', data_prefix='', test_mode=False, pipeline=train_pipeline), @@ -35,14 +35,24 @@ num_workers=2, dataset=dict(type=dataset_type, ann_file='data/rPPG-BP-UKL_rppg_7s.h5', - used_subjects=[6], + used_idx_file='data/val.txt', data_prefix='', test_mode=True, pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), ) val_evaluator = dict(type='MAE', gt_key='gt_label', pred_key='pred_label') -test_dataloader = val_dataloader +test_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict(type=dataset_type, + 
ann_file='data/rPPG-BP-UKL_rppg_7s.h5', + used_idx_file='data/test.txt', + data_prefix='', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) test_evaluator = val_evaluator # optimizer diff --git a/deep_vital/datasets/ppg_signal_dataset.py b/deep_vital/datasets/ppg_signal_dataset.py index 9d63a7f..92b6dbe 100644 --- a/deep_vital/datasets/ppg_signal_dataset.py +++ b/deep_vital/datasets/ppg_signal_dataset.py @@ -20,7 +20,7 @@ def expanduser(path): class RppgData(BaseDataset): def __init__(self, ann_file, - used_subjects=(), + used_idx_file=None, metainfo=None, data_root='', data_prefix='', @@ -34,7 +34,7 @@ def __init__(self, self.label = None self.rppg = None self.subject_idx = None - self.used_subjects=used_subjects + self.used_idx_file=used_idx_file self.num = None if isinstance(data_prefix, str): @@ -70,14 +70,14 @@ def load_data_list(self): self.subject_idx = np.array(data.get('/subject_idx'), dtype=int)[0, :] subjects_list = np.unique(self.subject_idx) - if self.used_subjects: - idx_used = np.where(np.isin(self.subject_idx, self.used_subjects))[-1] + if self.used_idx_file and os.path.exists(self.used_idx_file): + idx_used = np.loadtext(self.used_idx_file, delimiter=',') self.label = self.label[idx_used] self.rppg = self.rppg[idx_used] self.subject_idx = self.subject_idx[idx_used] - + self.num = self.subject_idx.shape[0] - + data_list = [] for _label, _rppg, _subject_idx in zip(self.label, self.rppg, self.subject_idx): # sbp_label = _label[0] From 1676926c198db792ba09660b4307d9609f46f224 Mon Sep 17 00:00:00 2001 From: WANG Tengfei Date: Sun, 21 May 2023 06:41:16 +0000 Subject: [PATCH 06/13] [feature] update ppg data --- configs/resnet/resnet50_1xb16_bp_ppg.py | 82 +++++++++++++++++++ ..._1xb16_bp.py => resnet50_1xb16_bp_rppg.py} | 7 +- deep_vital/datasets/__init__.py | 3 +- deep_vital/datasets/ppg_signal_dataset.py | 19 +++-- deep_vital/datasets/rppg_signal_dataset.py | 8 ++ 5 files changed, 108 insertions(+), 11 deletions(-) create mode 100644 configs/resnet/resnet50_1xb16_bp_ppg.py rename configs/resnet/{resnet50_1xb16_bp.py => resnet50_1xb16_bp_rppg.py} (89%) create mode 100644 deep_vital/datasets/rppg_signal_dataset.py diff --git a/configs/resnet/resnet50_1xb16_bp_ppg.py b/configs/resnet/resnet50_1xb16_bp_ppg.py new file mode 100644 index 0000000..4ee0a2d --- /dev/null +++ b/configs/resnet/resnet50_1xb16_bp_ppg.py @@ -0,0 +1,82 @@ +_base_ = ['../_base_/default_runtime.py'] + +model = dict(type='BPResNet1D', + data_preprocessor=dict(type='RppgDataPreprocessor'), + backbone=dict( + type='ResNet1D', + depth=50, + frozen_stages=4, + init_cfg=dict( + type='Pretrained', + checkpoint='data/resnet_ppg_nonmixed_backbone.pth')), + neck=dict(type='AveragePooling'), + head=dict(type='BPDenseHead', loss=dict(type='MSELoss'))) + +dataset_type = 'PpgData' +train_pipeline = [ + dict(type='PackInputs', input_key='ppg'), +] +test_pipeline = [ + dict(type='PackInputs', input_key='ppg'), +] +train_dataloader = dict( + batch_size=32, + num_workers=2, + dataset=dict(type=dataset_type, + ann_file='data/MIMIC-III_ppg_dataset.h5', + used_idx_file='data/train.txt', + data_prefix='', + test_mode=False, + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), +) +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict(type=dataset_type, + ann_file='data/MIMIC-III_ppg_dataset.h5', + used_idx_file='data/val.txt', + data_prefix='', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', 
shuffle=False), +) +val_evaluator = dict(type='MAE', gt_key='gt_label', pred_key='pred_label') +test_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict(type=dataset_type, + ann_file='data/MIMIC-III_ppg_dataset.h5', + used_idx_file='data/test.txt', + data_prefix='', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = val_evaluator + +# optimizer +optim_wrapper = dict(optimizer=dict(type='Adam', + lr=0.001, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0, + amsgrad=False)) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +visualizer = dict( + type='UniversalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + dict(type='TensorboardVisBackend') + ], +) + +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=True, save_best='loss', rule='less'), +) +custom_hooks = [dict(type='EarlyStoppingHook', monitor='loss', rule='less', min_delta=0.01, strict=False, check_finite=True, patience=5)] \ No newline at end of file diff --git a/configs/resnet/resnet50_1xb16_bp.py b/configs/resnet/resnet50_1xb16_bp_rppg.py similarity index 89% rename from configs/resnet/resnet50_1xb16_bp.py rename to configs/resnet/resnet50_1xb16_bp_rppg.py index a8ca11b..30fee5b 100644 --- a/configs/resnet/resnet50_1xb16_bp.py +++ b/configs/resnet/resnet50_1xb16_bp_rppg.py @@ -20,7 +20,7 @@ dict(type='PackInputs', input_key='rppg'), ] train_dataloader = dict( - batch_size=16, + batch_size=32, num_workers=2, dataset=dict(type=dataset_type, ann_file='data/rPPG-BP-UKL_rppg_7s.h5', @@ -75,3 +75,8 @@ dict(type='TensorboardVisBackend') ], ) + +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=True, save_best='loss', rule='less'), +) +custom_hooks = [dict(type='EarlyStoppingHook', monitor='loss', rule='less', min_delta=0.01, strict=False, check_finite=True, patience=5)] \ No newline at end of file diff --git a/deep_vital/datasets/__init__.py b/deep_vital/datasets/__init__.py index c7de3cd..8c14c25 100644 --- a/deep_vital/datasets/__init__.py +++ b/deep_vital/datasets/__init__.py @@ -1,2 +1,3 @@ -from .ppg_signal_dataset import RppgData +from .ppg_signal_dataset import PpgData +from .rppg_signal_dataset import RppgData from .transforms import * \ No newline at end of file diff --git a/deep_vital/datasets/ppg_signal_dataset.py b/deep_vital/datasets/ppg_signal_dataset.py index 92b6dbe..5e4f46d 100644 --- a/deep_vital/datasets/ppg_signal_dataset.py +++ b/deep_vital/datasets/ppg_signal_dataset.py @@ -17,7 +17,8 @@ def expanduser(path): @DATASETS.register_module() -class RppgData(BaseDataset): +class PpgData(BaseDataset): + _signal_name = 'ppg' def __init__(self, ann_file, used_idx_file=None, @@ -32,7 +33,7 @@ def __init__(self, lazy_init=False, max_refetch=1000): self.label = None - self.rppg = None + self.ppg = None self.subject_idx = None self.used_idx_file=used_idx_file self.num = None @@ -65,24 +66,24 @@ def load_data_list(self): assert os.path.exists(self.ann_file), self.ann_file data = h5py.File(self.ann_file, 'r') - self.label = np.array(data.get('/label')).T - self.rppg = np.array(data.get('/rppg')).T - self.subject_idx = np.array(data.get('/subject_idx'), dtype=int)[0, :] + self.label = np.array(data.get('/label')) + self.ppg = np.array(data.get(f'/{self.signal_name}')) + self.subject_idx = np.array(data.get('/subject_idx'), dtype=int) subjects_list = np.unique(self.subject_idx) if 
self.used_idx_file and os.path.exists(self.used_idx_file): - idx_used = np.loadtext(self.used_idx_file, delimiter=',') + idx_used = np.loadtxt(self.used_idx_file, dtype=np.int64, delimiter=',') self.label = self.label[idx_used] - self.rppg = self.rppg[idx_used] + self.ppg = self.ppg[idx_used] self.subject_idx = self.subject_idx[idx_used] self.num = self.subject_idx.shape[0] data_list = [] - for _label, _rppg, _subject_idx in zip(self.label, self.rppg, self.subject_idx): + for _label, _ppg, _subject_idx in zip(self.label, self.ppg, self.subject_idx): # sbp_label = _label[0] # dbp_label = _label[1] # _label = _label.reshape((1, 2, 1)) - info = {'gt_label': _label, 'rppg': _rppg, 'subject_idx': _subject_idx} + info = {'gt_label': _label, 'ppg': _ppg, 'subject_idx': _subject_idx} data_list.append(info) return data_list diff --git a/deep_vital/datasets/rppg_signal_dataset.py b/deep_vital/datasets/rppg_signal_dataset.py new file mode 100644 index 0000000..b85b6a8 --- /dev/null +++ b/deep_vital/datasets/rppg_signal_dataset.py @@ -0,0 +1,8 @@ + +from deep_vital.registry import DATASETS +from .ppg_signal_dataset import PpgData + + +@DATASETS.register_module() +class RppgData(PpgData): + _signal_name = 'rppg' \ No newline at end of file From 6f00d3bb5aef3a9d58210f4df5ca8dac3e1fce56 Mon Sep 17 00:00:00 2001 From: wang-tf Date: Sun, 21 May 2023 15:03:36 +0800 Subject: [PATCH 07/13] [feature] add test; add split data --- configs/resnet/resnet50_1xb16_bp_ppg.py | 16 +- configs/resnet/resnet50_1xb16_bp_rppg.py | 2 +- deep_vital/datasets/ppg_signal_dataset.py | 10 +- deep_vital/models/data_processors/__init__.py | 2 +- .../models/data_processors/data_processor.py | 4 +- tools/split_data_from_h5.py | 113 +++++++++++ tools/test.py | 185 ++++++++++++++++++ 7 files changed, 312 insertions(+), 20 deletions(-) create mode 100644 tools/split_data_from_h5.py create mode 100644 tools/test.py diff --git a/configs/resnet/resnet50_1xb16_bp_ppg.py b/configs/resnet/resnet50_1xb16_bp_ppg.py index 4ee0a2d..24db927 100644 --- a/configs/resnet/resnet50_1xb16_bp_ppg.py +++ b/configs/resnet/resnet50_1xb16_bp_ppg.py @@ -1,14 +1,11 @@ _base_ = ['../_base_/default_runtime.py'] model = dict(type='BPResNet1D', - data_preprocessor=dict(type='RppgDataPreprocessor'), + data_preprocessor=dict(type='DataPreprocessor'), backbone=dict( type='ResNet1D', depth=50, - frozen_stages=4, - init_cfg=dict( - type='Pretrained', - checkpoint='data/resnet_ppg_nonmixed_backbone.pth')), + ), neck=dict(type='AveragePooling'), head=dict(type='BPDenseHead', loss=dict(type='MSELoss'))) @@ -23,8 +20,7 @@ batch_size=32, num_workers=2, dataset=dict(type=dataset_type, - ann_file='data/MIMIC-III_ppg_dataset.h5', - used_idx_file='data/train.txt', + ann_file='data/mimic-iii_data/train.h5', data_prefix='', test_mode=False, pipeline=train_pipeline), @@ -34,8 +30,7 @@ batch_size=16, num_workers=2, dataset=dict(type=dataset_type, - ann_file='data/MIMIC-III_ppg_dataset.h5', - used_idx_file='data/val.txt', + ann_file='data/mimic-iii_data/val.h5', data_prefix='', test_mode=True, pipeline=test_pipeline), @@ -46,8 +41,7 @@ batch_size=16, num_workers=2, dataset=dict(type=dataset_type, - ann_file='data/MIMIC-III_ppg_dataset.h5', - used_idx_file='data/test.txt', + ann_file='data/mimic-iii_data/test.h5', data_prefix='', test_mode=True, pipeline=test_pipeline), diff --git a/configs/resnet/resnet50_1xb16_bp_rppg.py b/configs/resnet/resnet50_1xb16_bp_rppg.py index 30fee5b..23e3782 100644 --- a/configs/resnet/resnet50_1xb16_bp_rppg.py +++ 
b/configs/resnet/resnet50_1xb16_bp_rppg.py @@ -1,7 +1,7 @@ _base_ = ['../_base_/default_runtime.py'] model = dict(type='BPResNet1D', - data_preprocessor=dict(type='RppgDataPreprocessor'), + data_preprocessor=dict(type='DataPreprocessor'), backbone=dict( type='ResNet1D', depth=50, diff --git a/deep_vital/datasets/ppg_signal_dataset.py b/deep_vital/datasets/ppg_signal_dataset.py index 5e4f46d..42f4625 100644 --- a/deep_vital/datasets/ppg_signal_dataset.py +++ b/deep_vital/datasets/ppg_signal_dataset.py @@ -33,7 +33,7 @@ def __init__(self, lazy_init=False, max_refetch=1000): self.label = None - self.ppg = None + self.signal = None self.subject_idx = None self.used_idx_file=used_idx_file self.num = None @@ -67,23 +67,23 @@ def load_data_list(self): data = h5py.File(self.ann_file, 'r') self.label = np.array(data.get('/label')) - self.ppg = np.array(data.get(f'/{self.signal_name}')) + self.signal = np.array(data.get(f'/{self.signal_name}')) self.subject_idx = np.array(data.get('/subject_idx'), dtype=int) subjects_list = np.unique(self.subject_idx) if self.used_idx_file and os.path.exists(self.used_idx_file): idx_used = np.loadtxt(self.used_idx_file, dtype=np.int64, delimiter=',') self.label = self.label[idx_used] - self.ppg = self.ppg[idx_used] + self.signal = self.signal[idx_used] self.subject_idx = self.subject_idx[idx_used] self.num = self.subject_idx.shape[0] data_list = [] - for _label, _ppg, _subject_idx in zip(self.label, self.ppg, self.subject_idx): + for _label, _signal, _subject_idx in zip(self.label, self.signal, self.subject_idx): # sbp_label = _label[0] # dbp_label = _label[1] # _label = _label.reshape((1, 2, 1)) - info = {'gt_label': _label, 'ppg': _ppg, 'subject_idx': _subject_idx} + info = {'gt_label': _label, f'{self._signal_name}': _signal, 'subject_idx': _subject_idx} data_list.append(info) return data_list diff --git a/deep_vital/models/data_processors/__init__.py b/deep_vital/models/data_processors/__init__.py index c7ab907..e85b8f7 100644 --- a/deep_vital/models/data_processors/__init__.py +++ b/deep_vital/models/data_processors/__init__.py @@ -1 +1 @@ -from .data_processor import RppgDataPreprocessor \ No newline at end of file +from .data_processor import DataPreprocessor \ No newline at end of file diff --git a/deep_vital/models/data_processors/data_processor.py b/deep_vital/models/data_processors/data_processor.py index 326a6b9..09731c2 100644 --- a/deep_vital/models/data_processors/data_processor.py +++ b/deep_vital/models/data_processors/data_processor.py @@ -9,7 +9,7 @@ @MODELS.register_module() -class RppgDataPreprocessor(BaseDataPreprocessor): +class DataPreprocessor(BaseDataPreprocessor): def __init__(self, pad_size_divisor=1, pad_value: Union[float, int] = 0, non_blocking: Optional[bool] = False): super().__init__(non_blocking) self.pad_size_divisor = pad_size_divisor @@ -28,7 +28,7 @@ def forward(self, data: dict, training:bool =False): self.pad_value) elif isinstance(_batch_inputs, torch.Tensor): assert _batch_inputs.dim() == 3, ( - 'The input of `RppgDataPreprocessor` should be a NCL tensor ' + 'The input of `DataPreprocessor` should be a NCL tensor ' 'or a list of tensor, but got a tensor with shape: ' f'{_batch_inputs.shape}') _batch_inputs = _batch_inputs.float() diff --git a/tools/split_data_from_h5.py b/tools/split_data_from_h5.py new file mode 100644 index 0000000..8d34234 --- /dev/null +++ b/tools/split_data_from_h5.py @@ -0,0 +1,113 @@ +import os +import h5py +import numpy as np +import argparse +from sklearn.model_selection import train_test_split + + 
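+# Usage sketch (illustrative values; the MIMIC-III paths match the configs in
+# this series but are not produced by this script):
+#   python tools/split_data_from_h5.py data/MIMIC-III_ppg_dataset.h5 \
+#       data/mimic-iii_data --train_num 1000000 --val_num 250000 --test_num 250000
+# Note that the *_num arguments are cast with int(), so plain integers must be
+# passed: a command-line string such as '1e6' would raise ValueError.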
+def split_index(data_file, save_dir, train_num: int, val_num: int, test_num: int, divide_by_subject: bool, sig_name='ppg'): + assert os.path.exists(data_file), data_file + train_num = int(train_num) + val_num = int(val_num) + test_num = int(test_num) + + with h5py.File(data_file, 'r') as f: + signal = np.array(f.get(f'/{sig_name}')) + BP = np.array(f.get('/label')) + # BP = np.round(BP) + # BP = np.transpose(BP) + subject_idx = np.squeeze(np.array(f.get('/subject_idx'))) + N_samp_total = BP.shape[0] + subject_idx = subject_idx[:N_samp_total] + print(f'load data samples {N_samp_total}') + + # Divide the dataset into training, validation and test set + # ------------------------------------------------------------------------------- + if divide_by_subject is True: + valid_idx = np.arange(subject_idx.shape[-1]) + + # divide the subjects into training, validation and test subjects + subject_labels = np.unique(subject_idx) + subjects_train_labels, subjects_val_labels = train_test_split(subject_labels, test_size=0.5) + subjects_val_labels, subjects_test_labels = train_test_split(subjects_val_labels, test_size=0.5) + + # Calculate samples belong to training, validation and test subjects + train_part = valid_idx[np.isin(subject_idx,subjects_train_labels)] + val_part = valid_idx[np.isin(subject_idx,subjects_val_labels)] + test_part = valid_idx[np.isin(subject_idx, subjects_test_labels)] + + # draw a number samples defined by N_train, N_val and N_test from the training, validation and test subjects + idx_train = np.random.choice(train_part, train_num, replace=False) + idx_val = np.random.choice(val_part, val_num, replace=False) + idx_test = np.random.choice(test_part, test_num, replace=False) + else: + # Create a subset of the whole dataset by drawing a number of subjects from the dataset. 
The total number of + # samples contributed by those subjects must equal N_train + N_val + _N_test + subject_labels, SampSubject_hist = np.unique(subject_idx, return_counts=True) + cumsum_samp = np.cumsum(SampSubject_hist) + subject_labels_train = subject_labels[:np.nonzero(cumsum_samp>(train_num+val_num+test_num))[0][0]] + idx_valid = np.nonzero(np.isin(subject_idx,subject_labels_train))[0] + + # divide subset randomly into training, validation and test set + idx_train, idx_val = train_test_split(idx_valid, train_size= train_num, test_size=val_num+test_num) + idx_val, idx_test = train_test_split(idx_val, test_size=0.5) + print(f'train data num: {len(idx_train)}') + print(f'val data num: {len(idx_val)}') + print(f'test data num: {len(idx_test)}') + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + train_index_save_path = os.path.join(save_dir, 'train_index.txt') + val_index_save_path = os.path.join(save_dir, 'val_index.txt') + test_index_save_path = os.path.join(save_dir, 'test_index.txt') + np.savetxt(train_index_save_path, idx_train, delimiter=',', fmt='%d') + np.savetxt(val_index_save_path, idx_val, delimiter=',', fmt='%d') + np.savetxt(test_index_save_path, idx_test, delimiter=',', fmt='%d') + + train_data_save_path = os.path.join(save_dir, 'train.h5') + val_data_save_path = os.path.join(save_dir, 'val.h5') + test_data_save_path = os.path.join(save_dir, 'test.h5') + with h5py.File(train_data_save_path, 'w') as f: + f.create_dataset(sig_name, data=signal[idx_train]) + f.create_dataset('label', data=BP[idx_train]) + f.create_dataset('subject_idx', data=subject_idx[idx_train]) + with h5py.File(val_data_save_path, 'w') as f: + f.create_dataset(sig_name, data=signal[idx_val]) + f.create_dataset('label', data=BP[idx_val]) + f.create_dataset('subject_idx', data=subject_idx[idx_val]) + with h5py.File(test_data_save_path, 'w') as f: + f.create_dataset(sig_name, data=signal[idx_test]) + f.create_dataset('label', data=BP[idx_test]) + f.create_dataset('subject_idx', data=subject_idx[idx_test]) + + return + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('data_file') + parser.add_argument('save_dir') + parser.add_argument('--train_num', default=1e6) + parser.add_argument('--val_num', default=2.5e5) + parser.add_argument('--test_num', default=2.5e5) + parser.add_argument('--divide', default=True) + parser.add_argument('--sig_name', default='ppg') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + print(args) + data_file = args.data_file + save_dir = args.save_dir + train_num = args.train_num + val_num = args.val_num + test_num = args.test_num + divide_by_subject=args.divide + sig_name = args.sig_name + split_index(data_file, save_dir, train_num, val_num, test_num, divide_by_subject, sig_name) + + +if __name__ == '__main__': + main() diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000..e24db6b --- /dev/null +++ b/tools/test.py @@ -0,0 +1,185 @@ +import argparse +import os +import os.path as osp +from copy import deepcopy + +import mmengine +from mmengine.config import Config, ConfigDict, DictAction +from mmengine.evaluator import DumpResults +from mmengine.runner import Runner + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMPreTrain test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation 
metrics') + parser.add_argument('--out', help='the file to output results.') + parser.add_argument( + '--out-item', + choices=['metrics', 'pred'], + help='To output whether metrics or predictions. ' + 'Defaults to output predictions.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--amp', + action='store_true', + help='enable automatic-mixed-precision test') + parser.add_argument( + '--show-dir', + help='directory where the visualization images will be saved.') + parser.add_argument( + '--show', + action='store_true', + help='whether to display the prediction results in a window.') + parser.add_argument( + '--interval', + type=int, + default=1, + help='visualize per interval samples.') + parser.add_argument( + '--wait-time', + type=float, + default=2, + help='display time of every window. (second)') + parser.add_argument( + '--no-pin-memory', + action='store_true', + help='whether to disable the pin_memory option in dataloaders.') + parser.add_argument( + '--tta', + action='store_true', + help='Whether to enable the Test-Time-Aug (TTA). If the config file ' + 'has `tta_pipeline` and `tta_model` fields, use them to determine the ' + 'TTA transforms and how to merge the TTA results. Otherwise, use flip ' + 'TTA by averaging classification score.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def merge_args(cfg, args): + """Merge CLI arguments to config.""" + cfg.launcher = args.launcher + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + # enable automatic-mixed-precision test + if args.amp: + cfg.test_cfg.fp16 = True + + # -------------------- visualization -------------------- + if args.show or (args.show_dir is not None): + assert 'visualization' in cfg.default_hooks, \ + 'VisualizationHook is not set in the `default_hooks` field of ' \ + 'config. 
Please set `visualization=dict(type="VisualizationHook")`' + + cfg.default_hooks.visualization.enable = True + cfg.default_hooks.visualization.show = args.show + cfg.default_hooks.visualization.wait_time = args.wait_time + cfg.default_hooks.visualization.out_dir = args.show_dir + cfg.default_hooks.visualization.interval = args.interval + + # -------------------- TTA related args -------------------- + if args.tta: + if 'tta_model' not in cfg: + cfg.tta_model = dict(type='mmpretrain.AverageClsScoreTTA') + if 'tta_pipeline' not in cfg: + test_pipeline = cfg.test_dataloader.dataset.pipeline + cfg.tta_pipeline = deepcopy(test_pipeline) + flip_tta = dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='RandomFlip', prob=1.), + dict(type='RandomFlip', prob=0.) + ], + [test_pipeline[-1]], + ]) + cfg.tta_pipeline[-1] = flip_tta + cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + + # ----------------- Default dataloader args ----------------- + default_dataloader_cfg = ConfigDict( + pin_memory=True, + collate_fn=dict(type='default_collate'), + ) + + def set_default_dataloader_cfg(cfg, field): + if cfg.get(field, None) is None: + return + dataloader_cfg = deepcopy(default_dataloader_cfg) + dataloader_cfg.update(cfg[field]) + cfg[field] = dataloader_cfg + if args.no_pin_memory: + cfg[field]['pin_memory'] = False + + set_default_dataloader_cfg(cfg, 'test_dataloader') + + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + return cfg + + +def main(): + args = parse_args() + + if args.out is None and args.out_item is not None: + raise ValueError('Please use `--out` argument to specify the ' + 'path of the output file before using `--out-item`.') + + # load config + cfg = Config.fromfile(args.config) + + # merge cli arguments to config + cfg = merge_args(cfg, args) + + # build the runner from config + runner = Runner.from_cfg(cfg) + + if args.out and args.out_item in ['pred', None]: + runner.test_evaluator.metrics.append( + DumpResults(out_file_path=args.out)) + + # start testing + metrics = runner.test() + + if args.out and args.out_item == 'metrics': + mmengine.dump(metrics, args.out) + + +if __name__ == '__main__': + main() \ No newline at end of file From be561439d2387977a74798523ec8280218ab71a8 Mon Sep 17 00:00:00 2001 From: wang-tf Date: Sun, 21 May 2023 16:22:40 +0800 Subject: [PATCH 08/13] [feature] fix error --- deep_vital/datasets/ppg_signal_dataset.py | 2 +- tools/dist_train.sh | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 tools/dist_train.sh diff --git a/deep_vital/datasets/ppg_signal_dataset.py b/deep_vital/datasets/ppg_signal_dataset.py index 42f4625..0f4bb40 100644 --- a/deep_vital/datasets/ppg_signal_dataset.py +++ b/deep_vital/datasets/ppg_signal_dataset.py @@ -67,7 +67,7 @@ def load_data_list(self): data = h5py.File(self.ann_file, 'r') self.label = np.array(data.get('/label')) - self.signal = np.array(data.get(f'/{self.signal_name}')) + self.signal = np.array(data.get(f'/{self._signal_name}')) self.subject_idx = np.array(data.get('/subject_idx'), dtype=int) subjects_list = np.unique(self.subject_idx) diff --git a/tools/dist_train.sh b/tools/dist_train.sh new file mode 100644 index 0000000..1eb32aa --- /dev/null +++ b/tools/dist_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname 
$0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} \ No newline at end of file From da62e204105973a95355468fa23b90585e32bf0c Mon Sep 17 00:00:00 2001 From: wang-tf Date: Mon, 22 May 2023 00:11:02 +0800 Subject: [PATCH 09/13] [feature] add val_loss --- deep_vital/evaluation/metrics/mae.py | 6 +++++- deep_vital/models/blood_pressure/bp_resnet.py | 13 +++++++++---- deep_vital/structures/data_sample.py | 4 ++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/deep_vital/evaluation/metrics/mae.py b/deep_vital/evaluation/metrics/mae.py index 33f1239..1f72779 100644 --- a/deep_vital/evaluation/metrics/mae.py +++ b/deep_vital/evaluation/metrics/mae.py @@ -58,7 +58,8 @@ def process(self, data_batch, data_samples): for data_sample in data_samples: result = { 'pred_label': data_sample['pred_label'], - 'gt_label': data_sample['gt_label'] + 'gt_label': data_sample['gt_label'], + 'pred_loss': data_sample['pred_loss'] } self.results.append(result) @@ -66,6 +67,9 @@ def compute_metrics(self, results): metrics = {} target = torch.stack([res['gt_label'] for res in results]) pred = torch.stack([res['pred_label'] for res in results]) + loss = torch.stack([res['pred_loss'] for res in results]) + val_loss = loss.mean() + metrics['val_loss'] = val_loss diff = abs(pred - target) sbp_mean = diff[:, 0].mean() diff --git a/deep_vital/models/blood_pressure/bp_resnet.py b/deep_vital/models/blood_pressure/bp_resnet.py index 856c702..baa6dd1 100644 --- a/deep_vital/models/blood_pressure/bp_resnet.py +++ b/deep_vital/models/blood_pressure/bp_resnet.py @@ -34,14 +34,15 @@ def with_neck(self) -> bool: return hasattr(self, 'neck') and self.neck is not None def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): - x = self.extract_feat(batch_inputs) + feats = self.extract_feat(batch_inputs) - results_list = self.head.predict(x, + results_list = self.head.predict(feats, batch_data_samples, rescale=rescale) + val_loss = self.head.loss(feats, batch_data_samples) results_list = torch.cat([results_list[0], results_list[1]], dim=1) batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list) + batch_data_samples, results_list, val_loss) return batch_data_samples def _forward(self, batch_inputs, batch_data_samples=None): @@ -80,9 +81,10 @@ def loss(self, inputs: torch.Tensor, dict[str, Tensor]: a dictionary of loss components """ feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) - def add_pred_to_datasample(self, data_samples, results_list): + def add_pred_to_datasample(self, data_samples, results_list, loss_list=None): """Add predictions to `DetDataSample`. 
Args: @@ -106,4 +108,7 @@ def add_pred_to_datasample(self, data_samples, results_list): """ for data_sample, pred_label in zip(data_samples, results_list): data_sample.pred_label = pred_label + if loss_list is not None: + for data_sample, pred_loss in zip(data_sample, loss_list): + data_sample.pred_loss = pred_loss return data_samples \ No newline at end of file diff --git a/deep_vital/structures/data_sample.py b/deep_vital/structures/data_sample.py index 9bd9aa1..56123c4 100644 --- a/deep_vital/structures/data_sample.py +++ b/deep_vital/structures/data_sample.py @@ -48,3 +48,7 @@ def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample': """Set ``pred_label``.""" self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor) return self + + def set_pred_loss(self, value: LABEL_TYPE) -> 'DataSample': + self.set_field(format_label(value), 'pred_loss', dtype=torch.Tensor) + return self From ab24a6cfae170fa3f58f60c426f067bd150f32b6 Mon Sep 17 00:00:00 2001 From: WANG Tengfei Date: Mon, 22 May 2023 02:07:31 +0000 Subject: [PATCH 10/13] [feature] add BlandAltmanPlot hook --- configs/resnet/resnet50_1xb16_bp_ppg.py | 28 ++++--- deep_vital/evaluation/metrics/__init__.py | 3 +- .../evaluation/metrics/bland_altman_plot.py | 73 +++++++++++++++++++ deep_vital/evaluation/metrics/mae.py | 58 +++------------ 4 files changed, 104 insertions(+), 58 deletions(-) create mode 100644 deep_vital/evaluation/metrics/bland_altman_plot.py diff --git a/configs/resnet/resnet50_1xb16_bp_ppg.py b/configs/resnet/resnet50_1xb16_bp_ppg.py index 24db927..ffc7607 100644 --- a/configs/resnet/resnet50_1xb16_bp_ppg.py +++ b/configs/resnet/resnet50_1xb16_bp_ppg.py @@ -5,7 +5,7 @@ backbone=dict( type='ResNet1D', depth=50, - ), + ), neck=dict(type='AveragePooling'), head=dict(type='BPDenseHead', loss=dict(type='MSELoss'))) @@ -17,8 +17,8 @@ dict(type='PackInputs', input_key='ppg'), ] train_dataloader = dict( - batch_size=32, - num_workers=2, + batch_size=128, + num_workers=4, dataset=dict(type=dataset_type, ann_file='data/mimic-iii_data/train.h5', data_prefix='', @@ -27,8 +27,8 @@ sampler=dict(type='DefaultSampler', shuffle=True), ) val_dataloader = dict( - batch_size=16, - num_workers=2, + batch_size=64, + num_workers=4, dataset=dict(type=dataset_type, ann_file='data/mimic-iii_data/val.h5', data_prefix='', @@ -36,10 +36,14 @@ pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), ) -val_evaluator = dict(type='MAE', gt_key='gt_label', pred_key='pred_label') +val_evaluator = [ + dict(type='MAE', gt_key='gt_label', + pred_key='pred_label', loss_key='pred_loss'), + dict(type='BlandAltmanPlot') +] test_dataloader = dict( - batch_size=16, - num_workers=2, + batch_size=64, + num_workers=4, dataset=dict(type=dataset_type, ann_file='data/mimic-iii_data/test.h5', data_prefix='', @@ -71,6 +75,10 @@ ) default_hooks = dict( - checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=True, save_best='loss', rule='less'), + checkpoint=dict(type='CheckpointHook', interval=1, + by_epoch=True, save_best='pred_loss', rule='less'), ) -custom_hooks = [dict(type='EarlyStoppingHook', monitor='loss', rule='less', min_delta=0.01, strict=False, check_finite=True, patience=5)] \ No newline at end of file +custom_hooks = [ + dict(type='EarlyStoppingHook', monitor='pred_loss', rule='less', + min_delta=0.01, strict=False, check_finite=True, patience=5) +] diff --git a/deep_vital/evaluation/metrics/__init__.py b/deep_vital/evaluation/metrics/__init__.py index 661e0df..b9bc067 100644 --- 
a/deep_vital/evaluation/metrics/__init__.py +++ b/deep_vital/evaluation/metrics/__init__.py @@ -1 +1,2 @@ -from .mae import MAE \ No newline at end of file +from .mae import MAE +from .bland_altman_plot import BlandAltmanPlot \ No newline at end of file diff --git a/deep_vital/evaluation/metrics/bland_altman_plot.py b/deep_vital/evaluation/metrics/bland_altman_plot.py new file mode 100644 index 0000000..408267a --- /dev/null +++ b/deep_vital/evaluation/metrics/bland_altman_plot.py @@ -0,0 +1,73 @@ +from typing import Optional +import torch +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg +import PIL.Image as Image +from mmengine.evaluator import BaseMetric +from mmengine.visualization import Visualizer +from mmengine import MessageHub +from deep_vital.registry import METRICS + +message_hub = MessageHub.get_current_instance() + + +def bland_altman_plot(data1, data2, *args, **kwargs): + data1 = np.asarray(data1.numpy()) + data2 = np.asarray(data2.numpy()) + mean = np.mean([data1, data2], axis=0) + diff = data1 - data2 # Difference between data1 and data2 + md = np.mean(diff) # Mean of the difference + sd = np.std(diff, axis=0) # Standard deviation of the difference + + plt.cla() + plt.scatter(mean, diff, *args, **kwargs) + plt.axhline(md, color='gray', linestyle='-') + plt.axhline(md + 1.96*sd, color='gray', linestyle='--') + plt.axhline(md - 1.96*sd, color='gray', linestyle='--') + canvas = FigureCanvasAgg(plt.gcf()) + canvas.draw() + w, h = canvas.get_width_height() + buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8) + buf.shape = (w, h, 4) + buf = np.roll(buf, 3, axis=2) + image = Image.frombytes('RGBA', (w, h), buf.tostring()) + image = np.asarray(image) + rgb_image = image[:, :, :3] + plt.close() + return rgb_image + + +@METRICS.register_module() +class BlandAltmanPlot(BaseMetric): + metric = 'BlandAltmanPlot' + + def __init__(self, + gt_key='gt_label', + pred_key='pred_label', + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.gt_key = gt_key + self.pred_key = pred_key + self.visualizer = Visualizer.get_current_instance() + + def process(self, data_batch, data_samples): + for data_sample in data_samples: + result = { + 'pred_label': data_sample[self.pred_key], + 'gt_label': data_sample[self.gt_key], + } + self.results.append(result) + + def compute_metrics(self, results): + metrics = {} + target = torch.stack([res['gt_label'] for res in results]) + pred = torch.stack([res['pred_label'] for res in results]) + + sbp_BAP = bland_altman_plot(pred[:, 0], target[:, 0]) + dbp_BAP = bland_altman_plot(pred[:, 1], target[:, 1]) + current_step = message_hub.get_info('epoch') + self.visualizer.add_image('sbp_BAP', sbp_BAP, current_step) + self.visualizer.add_image('dbp_BAP', dbp_BAP, current_step) + return metrics diff --git a/deep_vital/evaluation/metrics/mae.py b/deep_vital/evaluation/metrics/mae.py index 1f72779..0c38dd5 100644 --- a/deep_vital/evaluation/metrics/mae.py +++ b/deep_vital/evaluation/metrics/mae.py @@ -1,85 +1,49 @@ from typing import Optional import torch -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.backends.backend_agg import FigureCanvasAgg -import PIL.Image as Image from mmengine.evaluator import BaseMetric -from mmengine.visualization import Visualizer from mmengine import MessageHub from deep_vital.registry import METRICS message_hub = MessageHub.get_current_instance() -def 
bland_altman_plot(data1, data2, *args, **kwargs): - data1 = np.asarray(data1.numpy()) - data2 = np.asarray(data2.numpy()) - mean = np.mean([data1, data2], axis=0) - diff = data1 - data2 # Difference between data1 and data2 - md = np.mean(diff) # Mean of the difference - sd = np.std(diff, axis=0) # Standard deviation of the difference - - plt.cla() - plt.scatter(mean, diff, *args, **kwargs) - plt.axhline(md, color='gray', linestyle='-') - plt.axhline(md + 1.96*sd, color='gray', linestyle='--') - plt.axhline(md - 1.96*sd, color='gray', linestyle='--') - canvas = FigureCanvasAgg(plt.gcf()) - canvas.draw() - w, h = canvas.get_width_height() - buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8) - buf.shape = (w, h, 4) - buf = np.roll(buf, 3, axis=2) - image = Image.frombytes('RGBA', (w, h), buf.tostring()) - image = np.asarray(image) - rgb_image = image[:, :, :3] - plt.close() - return rgb_image - - @METRICS.register_module() class MAE(BaseMetric): metric = 'MAE' def __init__(self, gt_key='gt_label', - pred_key='pred', - mask_key=None, + pred_key='pred_label', + loss_key=None, collect_device: str = 'cpu', prefix: Optional[str] = None): super().__init__(collect_device=collect_device, prefix=prefix) self.gt_key = gt_key self.pred_key = pred_key - self.mask_key = mask_key - self.visualizer = Visualizer.get_current_instance() + self.loss_key = loss_key def process(self, data_batch, data_samples): for data_sample in data_samples: result = { - 'pred_label': data_sample['pred_label'], - 'gt_label': data_sample['gt_label'], - 'pred_loss': data_sample['pred_loss'] + 'pred_label': data_sample[self.pred_key], + 'gt_label': data_sample[self.gt_key] } + if self.loss_key: + result['pred_loss'] = data_sample[self.loss_key] self.results.append(result) def compute_metrics(self, results): metrics = {} target = torch.stack([res['gt_label'] for res in results]) pred = torch.stack([res['pred_label'] for res in results]) - loss = torch.stack([res['pred_loss'] for res in results]) - val_loss = loss.mean() - metrics['val_loss'] = val_loss + if self.loss_key: + loss = torch.stack([res['pred_loss'] for res in results]) + val_loss = loss.mean() + metrics['val_loss'] = val_loss diff = abs(pred - target) sbp_mean = diff[:, 0].mean() dbp_mean = diff[:, 1].mean() metrics['sbp_mae'] = sbp_mean metrics['dbp_mae'] = dbp_mean - - sbp_BAP = bland_altman_plot(pred[:, 0], target[:, 0]) - dbp_BAP = bland_altman_plot(pred[:, 1], target[:, 1]) - current_step = message_hub.get_info('epoch') - self.visualizer.add_image('sbp_BAP', sbp_BAP, current_step) - self.visualizer.add_image('dbp_BAP', dbp_BAP, current_step) return metrics From f4d0adeba9c88a140b76a89c7e443922a153abfd Mon Sep 17 00:00:00 2001 From: WANG Tengfei Date: Mon, 22 May 2023 04:06:22 +0000 Subject: [PATCH 11/13] [bugfix] comput loss during mae --- ...b16_bp_ppg.py => resnet50_4xb128_bp_ppg.py} | 8 ++++---- deep_vital/engine/hooks/val_loss.py | 0 .../evaluation/metrics/bland_altman_plot.py | 1 + deep_vital/evaluation/metrics/mae.py | 18 +++++++++++++----- deep_vital/models/blood_pressure/bp_resnet.py | 8 ++++---- deep_vital/models/losses/mse_loss.py | 2 +- 6 files changed, 23 insertions(+), 14 deletions(-) rename configs/resnet/{resnet50_1xb16_bp_ppg.py => resnet50_4xb128_bp_ppg.py} (96%) create mode 100644 deep_vital/engine/hooks/val_loss.py diff --git a/configs/resnet/resnet50_1xb16_bp_ppg.py b/configs/resnet/resnet50_4xb128_bp_ppg.py similarity index 96% rename from configs/resnet/resnet50_1xb16_bp_ppg.py rename to configs/resnet/resnet50_4xb128_bp_ppg.py index 
ffc7607..8cd66ef 100644 --- a/configs/resnet/resnet50_1xb16_bp_ppg.py +++ b/configs/resnet/resnet50_4xb128_bp_ppg.py @@ -18,7 +18,7 @@ ] train_dataloader = dict( batch_size=128, - num_workers=4, + num_workers=2, dataset=dict(type=dataset_type, ann_file='data/mimic-iii_data/train.h5', data_prefix='', @@ -28,7 +28,7 @@ ) val_dataloader = dict( batch_size=64, - num_workers=4, + num_workers=2, dataset=dict(type=dataset_type, ann_file='data/mimic-iii_data/val.h5', data_prefix='', @@ -39,11 +39,11 @@ val_evaluator = [ dict(type='MAE', gt_key='gt_label', pred_key='pred_label', loss_key='pred_loss'), - dict(type='BlandAltmanPlot') + # dict(type='BlandAltmanPlot') ] test_dataloader = dict( batch_size=64, - num_workers=4, + num_workers=2, dataset=dict(type=dataset_type, ann_file='data/mimic-iii_data/test.h5', data_prefix='', diff --git a/deep_vital/engine/hooks/val_loss.py b/deep_vital/engine/hooks/val_loss.py new file mode 100644 index 0000000..e69de29 diff --git a/deep_vital/evaluation/metrics/bland_altman_plot.py b/deep_vital/evaluation/metrics/bland_altman_plot.py index 408267a..d025875 100644 --- a/deep_vital/evaluation/metrics/bland_altman_plot.py +++ b/deep_vital/evaluation/metrics/bland_altman_plot.py @@ -41,6 +41,7 @@ def bland_altman_plot(data1, data2, *args, **kwargs): @METRICS.register_module() class BlandAltmanPlot(BaseMetric): metric = 'BlandAltmanPlot' + default_prefix = 'BlandAltmanPlot' def __init__(self, gt_key='gt_label', diff --git a/deep_vital/evaluation/metrics/mae.py b/deep_vital/evaluation/metrics/mae.py index 0c38dd5..87a9426 100644 --- a/deep_vital/evaluation/metrics/mae.py +++ b/deep_vital/evaluation/metrics/mae.py @@ -1,5 +1,6 @@ from typing import Optional import torch +import torch.nn.functional as F from mmengine.evaluator import BaseMetric from mmengine import MessageHub from deep_vital.registry import METRICS @@ -10,6 +11,7 @@ @METRICS.register_module() class MAE(BaseMetric): metric = 'MAE' + default_prefix = 'MAE' def __init__(self, gt_key='gt_label', @@ -28,8 +30,8 @@ def process(self, data_batch, data_samples): 'pred_label': data_sample[self.pred_key], 'gt_label': data_sample[self.gt_key] } - if self.loss_key: - result['pred_loss'] = data_sample[self.loss_key] + # if self.loss_key: + # result['pred_loss'] = data_sample[self.loss_key] self.results.append(result) def compute_metrics(self, results): @@ -37,9 +39,15 @@ def compute_metrics(self, results): target = torch.stack([res['gt_label'] for res in results]) pred = torch.stack([res['pred_label'] for res in results]) if self.loss_key: - loss = torch.stack([res['pred_loss'] for res in results]) - val_loss = loss.mean() - metrics['val_loss'] = val_loss + # loss = torch.stack([res['pred_loss'] for res in results]) + # val_loss = loss.mean() + # metrics['pred_loss'] = val_loss + sbp_pred, dbp_pred = pred + sbp_target = target[:, 0][..., None] + dbp_target = target[:, 1][..., None] + sbp_mse_loss = F.mse_loss(sbp_pred, sbp_target.detach()) + dbp_mse_loss = F.mse_loss(dbp_pred, dbp_target.detach()) + metrics['pred_loss'] = sbp_mse_loss + dbp_mse_loss diff = abs(pred - target) sbp_mean = diff[:, 0].mean() diff --git a/deep_vital/models/blood_pressure/bp_resnet.py b/deep_vital/models/blood_pressure/bp_resnet.py index baa6dd1..0f1763c 100644 --- a/deep_vital/models/blood_pressure/bp_resnet.py +++ b/deep_vital/models/blood_pressure/bp_resnet.py @@ -39,10 +39,10 @@ def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): results_list = self.head.predict(feats, batch_data_samples, rescale=rescale) - val_loss 
= self.head.loss(feats, batch_data_samples) + # val_loss = self.head.loss(feats, batch_data_samples) results_list = torch.cat([results_list[0], results_list[1]], dim=1) batch_data_samples = self.add_pred_to_datasample( - batch_data_samples, results_list, val_loss) + batch_data_samples, results_list) return batch_data_samples def _forward(self, batch_inputs, batch_data_samples=None): @@ -109,6 +109,6 @@ def add_pred_to_datasample(self, data_samples, results_list, loss_list=None): for data_sample, pred_label in zip(data_samples, results_list): data_sample.pred_label = pred_label if loss_list is not None: - for data_sample, pred_loss in zip(data_sample, loss_list): - data_sample.pred_loss = pred_loss + for data_sample, pred_loss in zip(data_samples, loss_list): + data_sample.pred_loss = pred_loss['loss'] return data_samples \ No newline at end of file diff --git a/deep_vital/models/losses/mse_loss.py b/deep_vital/models/losses/mse_loss.py index dc86de2..f2f4be3 100644 --- a/deep_vital/models/losses/mse_loss.py +++ b/deep_vital/models/losses/mse_loss.py @@ -44,5 +44,5 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tenso # TODO: set float32 in dataset loader target = target.to(torch.float32) # loss_align = self.loss_mse(pred, target.detach()) - mse_loss = F.mse_loss(pred, target) + mse_loss = F.mse_loss(pred, target.detach()) return mse_loss \ No newline at end of file From 344ac21e6ec70e2f4fceaadbe3026d69ee5c11e7 Mon Sep 17 00:00:00 2001 From: WANG Tengfei Date: Mon, 22 May 2023 06:49:21 +0000 Subject: [PATCH 12/13] [bugfix] compute val loss --- configs/resnet/resnet50_4xb128_bp_ppg.py | 4 +- deep_vital/evaluation/metrics/mae.py | 7 +- deep_vital/models/backbones/lstm.py | 44 +++++++ deep_vital/models/backbones/resnet_1d.py | 117 ------------------ deep_vital/models/blood_pressure/bp_lstm.py | 112 +++++++++++++++++ deep_vital/models/blood_pressure/bp_resnet.py | 1 - 6 files changed, 162 insertions(+), 123 deletions(-) create mode 100644 deep_vital/models/backbones/lstm.py create mode 100644 deep_vital/models/blood_pressure/bp_lstm.py diff --git a/configs/resnet/resnet50_4xb128_bp_ppg.py b/configs/resnet/resnet50_4xb128_bp_ppg.py index 8cd66ef..573bce3 100644 --- a/configs/resnet/resnet50_4xb128_bp_ppg.py +++ b/configs/resnet/resnet50_4xb128_bp_ppg.py @@ -76,9 +76,9 @@ default_hooks = dict( checkpoint=dict(type='CheckpointHook', interval=1, - by_epoch=True, save_best='pred_loss', rule='less'), + by_epoch=True, save_best='MAE/pred_loss', rule='less'), ) custom_hooks = [ - dict(type='EarlyStoppingHook', monitor='pred_loss', rule='less', + dict(type='EarlyStoppingHook', monitor='MAE/pred_loss', rule='less', min_delta=0.01, strict=False, check_finite=True, patience=5) ] diff --git a/deep_vital/evaluation/metrics/mae.py b/deep_vital/evaluation/metrics/mae.py index 87a9426..9e98f31 100644 --- a/deep_vital/evaluation/metrics/mae.py +++ b/deep_vital/evaluation/metrics/mae.py @@ -42,9 +42,10 @@ def compute_metrics(self, results): # loss = torch.stack([res['pred_loss'] for res in results]) # val_loss = loss.mean() # metrics['pred_loss'] = val_loss - sbp_pred, dbp_pred = pred - sbp_target = target[:, 0][..., None] - dbp_target = target[:, 1][..., None] + sbp_pred = pred[:, 0] + dbp_pred = pred[:, 1] + sbp_target = target[:, 0] + dbp_target = target[:, 1] sbp_mse_loss = F.mse_loss(sbp_pred, sbp_target.detach()) dbp_mse_loss = F.mse_loss(dbp_pred, dbp_target.detach()) metrics['pred_loss'] = sbp_mse_loss + dbp_mse_loss diff --git a/deep_vital/models/backbones/lstm.py 
b/deep_vital/models/backbones/lstm.py new file mode 100644 index 0000000..6095dba --- /dev/null +++ b/deep_vital/models/backbones/lstm.py @@ -0,0 +1,44 @@ +from torch import nn +from mmengine.model import BaseModel +from mmcv.cnn import build_conv_layer +from deep_vital.registry import MODELS + + +@MODELS.register_module() +class LSTMBackbone(BaseModel): + def __init__(self, in_channels, conv_cfg=dict(type='Conv1d'), init_cfg=None) -> None: + super().__init__(init_cfg=init_cfg) + self.relu = nn.ReLU(inplace=True) + + self.conv = build_conv_layer(conv_cfg, + in_channels, + 64, + kernel_size=5, + stride=1, + padding=[4, 0], + dilation=1, + bias=True) + + self.layer1 = nn.LSTM(64, 128, bidirectional=True) + self.layer2 = nn.LSTM(128, 128, bidirectional=True) + self.layer3 = nn.LSTM(128, 64, bidirectional=True) + + self.layer4 = nn.Linear(64, 512) + self.layer5 = nn.Linear(512, 256) + self.layer6 = nn.Linear(256, 128) + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.layer4(x) + x = self.relu(x) + x = self.layer5(x) + x = self.relu(x) + x = self.layer6(x) + x = self.relu(x) + return x diff --git a/deep_vital/models/backbones/resnet_1d.py b/deep_vital/models/backbones/resnet_1d.py index 09befba..cfb1a13 100644 --- a/deep_vital/models/backbones/resnet_1d.py +++ b/deep_vital/models/backbones/resnet_1d.py @@ -14,123 +14,6 @@ eps = 1.0e-5 -class identity_block(nn.Module): - - def __init__(self, dims, f, filters, dil=1) -> None: - super().__init__() - F1, F2, F3 = filters - - self.conv_2a = nn.Conv1d(dims, - F1, - kernel_size=1, - stride=1, - dilation=dil, - padding='valid') - self.bn_2a = nn.BatchNorm1d(F1, momentum=1 - 0.9) - self.relu = nn.ReLU(inplace=True) - - self.conv_2b = nn.Conv1d(F1, - F2, - kernel_size=f, - stride=1, - dilation=dil, - padding='same') - self.bn_2b = nn.BatchNorm1d(F2, momentum=1 - 0.9) - # RELU - - self.conv_2c = nn.Conv1d(F2, - F3, - kernel_size=1, - stride=1, - dilation=dil, - padding='valid') - self.bn_2c = nn.BatchNorm1d(F3, momentum=1 - 0.9) - - def forward(self, x): - x_shortcut = x - - x = self.conv_2a(x) - x = self.bn_2a(x) - x = self.relu(x) - - x = self.conv_2b(x) - x = self.bn_2b(x) - x = self.relu(x) - - x = self.conv_2c(x) - x = self.bn_2c(x) - - x = x + x_shortcut - x = self.relu(x) - return x - - -class convolutional_block(nn.Module): - - def __init__(self, dims, f, filters, s=2, dil=1) -> None: - super().__init__() - # Retrieve Filters - F1, F2, F3 = filters - - ##### MAIN PATH ##### - # First component of main path - self.conv_2a = nn.Conv1d(dims, F1, kernel_size=1, stride=s) - self.bn_2a = nn.BatchNorm1d(F1, momentum=1 - 0.9) - self.relu = nn.ReLU(inplace=True) - - # Second component of main path (≈3 lines) - self.conv_2b = nn.Conv1d(F1, - F2, - kernel_size=f, - stride=1, - dilation=dil, - padding='same') - self.bn_2b = nn.BatchNorm1d(F2, momentum=1 - 0.9) - - # Third component of main path (≈2 lines) - self.conv_2c = nn.Conv1d(F2, - F3, - kernel_size=1, - stride=1, - dilation=dil, - padding='valid') - self.bn_2c = nn.BatchNorm1d(F3, momentum=1 - 0.9) - - ##### SHORTCUT PATH #### (≈2 lines) - self.conv_1 = nn.Conv1d(dims, - F3, - kernel_size=1, - stride=s, - dilation=dil, - padding='valid') - self.bn_1 = nn.BatchNorm1d(F3, momentum=1 - 0.9) - - def forward(self, x): - x_shortcut = x - # First component of main path - x = self.conv_2a(x) - x = self.bn_2a(x) - x = self.relu(x) - - # Second component of main path (≈3 lines) - x = self.conv_2b(x) - x = self.bn_2b(x) - 
diff --git a/deep_vital/models/backbones/resnet_1d.py b/deep_vital/models/backbones/resnet_1d.py
index 09befba..cfb1a13 100644
--- a/deep_vital/models/backbones/resnet_1d.py
+++ b/deep_vital/models/backbones/resnet_1d.py
@@ -14,123 +14,6 @@
 eps = 1.0e-5
 
 
-class identity_block(nn.Module):
-
-    def __init__(self, dims, f, filters, dil=1) -> None:
-        super().__init__()
-        F1, F2, F3 = filters
-
-        self.conv_2a = nn.Conv1d(dims,
-                                 F1,
-                                 kernel_size=1,
-                                 stride=1,
-                                 dilation=dil,
-                                 padding='valid')
-        self.bn_2a = nn.BatchNorm1d(F1, momentum=1 - 0.9)
-        self.relu = nn.ReLU(inplace=True)
-
-        self.conv_2b = nn.Conv1d(F1,
-                                 F2,
-                                 kernel_size=f,
-                                 stride=1,
-                                 dilation=dil,
-                                 padding='same')
-        self.bn_2b = nn.BatchNorm1d(F2, momentum=1 - 0.9)
-        # RELU
-
-        self.conv_2c = nn.Conv1d(F2,
-                                 F3,
-                                 kernel_size=1,
-                                 stride=1,
-                                 dilation=dil,
-                                 padding='valid')
-        self.bn_2c = nn.BatchNorm1d(F3, momentum=1 - 0.9)
-
-    def forward(self, x):
-        x_shortcut = x
-
-        x = self.conv_2a(x)
-        x = self.bn_2a(x)
-        x = self.relu(x)
-
-        x = self.conv_2b(x)
-        x = self.bn_2b(x)
-        x = self.relu(x)
-
-        x = self.conv_2c(x)
-        x = self.bn_2c(x)
-
-        x = x + x_shortcut
-        x = self.relu(x)
-        return x
-
-
-class convolutional_block(nn.Module):
-
-    def __init__(self, dims, f, filters, s=2, dil=1) -> None:
-        super().__init__()
-        # Retrieve Filters
-        F1, F2, F3 = filters
-
-        ##### MAIN PATH #####
-        # First component of main path
-        self.conv_2a = nn.Conv1d(dims, F1, kernel_size=1, stride=s)
-        self.bn_2a = nn.BatchNorm1d(F1, momentum=1 - 0.9)
-        self.relu = nn.ReLU(inplace=True)
-
-        # Second component of main path (≈3 lines)
-        self.conv_2b = nn.Conv1d(F1,
-                                 F2,
-                                 kernel_size=f,
-                                 stride=1,
-                                 dilation=dil,
-                                 padding='same')
-        self.bn_2b = nn.BatchNorm1d(F2, momentum=1 - 0.9)
-
-        # Third component of main path (≈2 lines)
-        self.conv_2c = nn.Conv1d(F2,
-                                 F3,
-                                 kernel_size=1,
-                                 stride=1,
-                                 dilation=dil,
-                                 padding='valid')
-        self.bn_2c = nn.BatchNorm1d(F3, momentum=1 - 0.9)
-
-        ##### SHORTCUT PATH #### (≈2 lines)
-        self.conv_1 = nn.Conv1d(dims,
-                                F3,
-                                kernel_size=1,
-                                stride=s,
-                                dilation=dil,
-                                padding='valid')
-        self.bn_1 = nn.BatchNorm1d(F3, momentum=1 - 0.9)
-
-    def forward(self, x):
-        x_shortcut = x
-        # First component of main path
-        x = self.conv_2a(x)
-        x = self.bn_2a(x)
-        x = self.relu(x)
-
-        # Second component of main path (≈3 lines)
-        x = self.conv_2b(x)
-        x = self.bn_2b(x)
-        x = self.relu(x)
-
-        # Third component of main path (≈2 lines)
-        x = self.conv_2c(x)
-        x = self.bn_2c(x)
-
-        x_shortcut = self.conv_1(x_shortcut)
-        x_shortcut = self.bn_1(x_shortcut)
-
-        # Final step: Add shortcut value to main path, and pass it through a RELU activation (≈2 lines)
-        x = x + x_shortcut
-        x = self.relu(x)
-        # X = BatchNormalization(momentum = 0.9, name = bn_name_base + '2c')(X)
-        return x
-
-
 class Bottleneck1D(BaseModel):
     """Bottleneck block for ResNet1D.
     """
diff --git a/deep_vital/models/blood_pressure/bp_lstm.py b/deep_vital/models/blood_pressure/bp_lstm.py
new file mode 100644
index 0000000..dad3216
--- /dev/null
+++ b/deep_vital/models/blood_pressure/bp_lstm.py
@@ -0,0 +1,103 @@
+from typing import List
+import torch
+from mmengine.model import BaseModel
+from deep_vital.registry import MODELS
+from deep_vital.structures import DataSample
+
+
+@MODELS.register_module()
+class BPLSTM(BaseModel):
+    def __init__(self,
+                 backbone,
+                 neck=None,
+                 head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 data_preprocessor=None,
+                 init_cfg=None) -> None:
+        super().__init__(data_preprocessor=data_preprocessor,
+                         init_cfg=init_cfg)
+        self.backbone = MODELS.build(backbone)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+        self.head = MODELS.build(head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+    @property
+    def with_neck(self) -> bool:
+        """bool: whether the model has a neck"""
+        return hasattr(self, 'neck') and self.neck is not None
+
+    def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
+        feats = self.extract_feat(batch_inputs)
+        results_list = self.head.predict(feats,
+                                         batch_data_samples,
+                                         rescale=rescale)
+        # The head predicts SBP and DBP separately; concatenate them
+        # into a single (N, 2) tensor.
+        results_list = torch.cat([results_list[0], results_list[1]], dim=1)
+        batch_data_samples = self.add_pred_to_datasample(
+            batch_data_samples, results_list)
+        return batch_data_samples
+
+    def _forward(self, batch_inputs, batch_data_samples=None):
+        x = self.extract_feat(batch_inputs)
+        return self.head.forward(x)
+
+    def extract_feat(self, batch_inputs):
+        x = self.backbone(batch_inputs)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def forward(self, inputs, data_samples=None, mode='tensor'):
+        if mode == 'tensor':
+            # ``_forward`` already applies the head, so its output is
+            # returned directly.
+            return self._forward(inputs)
+        elif mode == 'loss':
+            return self.loss(inputs, data_samples)
+        elif mode == 'predict':
+            return self.predict(inputs, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def loss(self, inputs: torch.Tensor,
+             data_samples: List[DataSample]) -> dict:
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            inputs (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample]): The annotation data of
+                every sample.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        feats = self.extract_feat(inputs)
+        return self.head.loss(feats, data_samples)
+
+    def add_pred_to_datasample(self, data_samples, results_list, loss_list=None):
+        """Attach predictions to each ``DataSample``.
+
+        Args:
+            data_samples (list[:obj:`DataSample`]): A batch of data
+                samples that contain annotations.
+            results_list (Tensor): Predicted blood pressure values with
+                shape (N, 2), arranged as (SBP, DBP).
+            loss_list (list[dict], optional): Per-sample loss dicts
+                whose ``'loss'`` entry is stored as ``pred_loss``.
+
+        Returns:
+            list[:obj:`DataSample`]: The same data samples with
+            ``pred_label`` (and optionally ``pred_loss``) set.
+        """
+        for data_sample, pred_label in zip(data_samples, results_list):
+            data_sample.pred_label = pred_label
+        if loss_list is not None:
+            for data_sample, pred_loss in zip(data_samples, loss_list):
+                data_sample.pred_loss = pred_loss['loss']
+        return data_samples
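The `predict` path above boils down to attaching an (N, 2) tensor of (SBP, DBP) values to each sample. A minimal sketch of that contract, assuming `DataSample` follows mmengine's `BaseDataElement` interface (the numeric values are made up):

```python
import torch
from deep_vital.structures import DataSample

samples = [DataSample(), DataSample()]
preds = torch.tensor([[120.0, 80.0], [135.0, 85.0]])  # (N, 2) as (SBP, DBP)
for sample, pred in zip(samples, preds):
    sample.pred_label = pred  # what add_pred_to_datasample does per sample
print(samples[0].pred_label)  # tensor([120.,  80.])
```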
diff --git a/deep_vital/models/blood_pressure/bp_resnet.py b/deep_vital/models/blood_pressure/bp_resnet.py
index 0f1763c..467de3f 100644
--- a/deep_vital/models/blood_pressure/bp_resnet.py
+++ b/deep_vital/models/blood_pressure/bp_resnet.py
@@ -1,5 +1,4 @@
 from typing import List
-from abc import ABCMeta, abstractmethod
 import torch
 from torch import nn
 from mmengine.model import BaseModel

From 9671ed8428c0c40fc86a62f54932b909028839db Mon Sep 17 00:00:00 2001
From: WANG Tengfei
Date: Mon, 22 May 2023 07:21:10 +0000
Subject: [PATCH 13/13] [feature] change val loss arg name

---
 configs/resnet/resnet50_4xb128_bp_ppg.py |  7 +++++--
 deep_vital/evaluation/metrics/mae.py     | 14 +++-----------
 deep_vital/version.py                    |  2 +-
 3 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/configs/resnet/resnet50_4xb128_bp_ppg.py b/configs/resnet/resnet50_4xb128_bp_ppg.py
index 573bce3..0c73e1f 100644
--- a/configs/resnet/resnet50_4xb128_bp_ppg.py
+++ b/configs/resnet/resnet50_4xb128_bp_ppg.py
@@ -38,7 +38,7 @@
 )
 val_evaluator = [
     dict(type='MAE', gt_key='gt_label',
-         pred_key='pred_label', loss_key='pred_loss'),
+         pred_key='pred_label', compute_loss=True),
     # dict(type='BlandAltmanPlot')
 ]
 test_dataloader = dict(
@@ -51,7 +51,10 @@
         pipeline=test_pipeline),
     sampler=dict(type='DefaultSampler', shuffle=False),
 )
-test_evaluator = val_evaluator
+test_evaluator = [
+    dict(type='MAE', gt_key='gt_label',
+         pred_key='pred_label')
+]
 
 # optimizer
 optim_wrapper = dict(optimizer=dict(type='Adam',
diff --git a/deep_vital/evaluation/metrics/mae.py b/deep_vital/evaluation/metrics/mae.py
index 9e98f31..cbb9272 100644
--- a/deep_vital/evaluation/metrics/mae.py
+++ b/deep_vital/evaluation/metrics/mae.py
@@ -2,11 +2,8 @@
 import torch
 import torch.nn.functional as F
 from mmengine.evaluator import BaseMetric
-from mmengine import MessageHub
 from deep_vital.registry import METRICS
 
-message_hub = MessageHub.get_current_instance()
-
 
 @METRICS.register_module()
 class MAE(BaseMetric):
@@ -16,13 +13,13 @@ class MAE(BaseMetric):
     def __init__(self,
                  gt_key='gt_label',
                  pred_key='pred_label',
-                 loss_key=None,
+                 compute_loss=False,
                  collect_device: str = 'cpu',
                  prefix: Optional[str] = None):
         super().__init__(collect_device=collect_device, prefix=prefix)
         self.gt_key = gt_key
         self.pred_key = pred_key
-        self.loss_key = loss_key
+        self.compute_loss = compute_loss
 
     def process(self, data_batch, data_samples):
         for data_sample in data_samples:
@@ -30,18 +27,13 @@ def process(self, data_batch, data_samples):
                 'pred_label': data_sample[self.pred_key],
                 'gt_label': data_sample[self.gt_key]
             }
-            # if self.loss_key:
-            #     result['pred_loss'] = data_sample[self.loss_key]
             self.results.append(result)
 
     def compute_metrics(self, results):
         metrics = {}
         target = torch.stack([res['gt_label'] for res in results])
         pred = torch.stack([res['pred_label'] for res in results])
-        if self.loss_key:
-            # loss = torch.stack([res['pred_loss'] for res in results])
-            # val_loss = loss.mean()
-            # metrics['pred_loss'] = val_loss
+        if self.compute_loss:
             sbp_pred = pred[:, 0]
             dbp_pred = pred[:, 1]
             sbp_target = target[:, 0]
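For reference, the per-target error that the `compute_loss` branch splits out corresponds to a plain column-wise MAE. A standalone sketch with made-up values (the exact keys returned by `compute_metrics` lie outside this hunk):

```python
import torch

target = torch.tensor([[120.0, 80.0], [135.0, 85.0]])  # (SBP, DBP) ground truth
pred = torch.tensor([[118.0, 82.0], [140.0, 84.0]])
mae_sbp = (pred[:, 0] - target[:, 0]).abs().mean()  # 3.5
mae_dbp = (pred[:, 1] - target[:, 1]).abs().mean()  # 1.5
```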
diff --git a/deep_vital/version.py b/deep_vital/version.py
index 9d24ae5..5fdda83 100644
--- a/deep_vital/version.py
+++ b/deep_vital/version.py
@@ -1,4 +1,4 @@
-__version__ = '0.1.0'
+__version__ = '1.0.0'
 short_version = __version__
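The renamed flag keeps the metric self-contained, so it can be exercised without a full runner. A minimal sketch, assuming the standard mmengine `Evaluator` interface (the keys of the printed dict come from `compute_metrics`, which the hunk above only partially shows):

```python
import torch
from mmengine.evaluator import Evaluator
from deep_vital.evaluation.metrics.mae import MAE

# Mirrors the new test_evaluator entry: compute_loss stays False.
evaluator = Evaluator(MAE(gt_key='gt_label', pred_key='pred_label'))
data_samples = [dict(gt_label=torch.tensor([120.0, 80.0]),
                     pred_label=torch.tensor([118.0, 82.0]))]
evaluator.process(data_samples=data_samples, data_batch=None)
print(evaluator.evaluate(size=len(data_samples)))
```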