Skip to content

Commit

Permalink
Merge pull request #2 from wang-tf:develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
wang-tf authored May 22, 2023
2 parents 8d83b69 + 9671ed8 commit 7e73dce
Show file tree
Hide file tree
Showing 26 changed files with 773 additions and 197 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: deploy

on: push

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Install torch
run: pip install torch
- name: Install wheel
run: pip install wheel
- name: Build Deep Vital
run: python setup.py sdist bdist_wheel
- name: Publish distribution to PyPI
run: |
pip install twine
twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# deep_vital
# Deep Vital

Pytorch version of [non-invasive-bp-estimation-using-deep-learning](https://github.com/Fabian-Sc85/non-invasive-bp-estimation-using-deep-learning)

pretrained model can be found from [zendo](https://zenodo.org/record/7948098)

## Models
- [ResNet](configs/resnet/README.md)
23 changes: 23 additions & 0 deletions configs/resnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ResNet Blood Pressure Model

> [Assessment of Non-Invasive Blood Pressure Prediction from PPG and rPPG Signals Using Deep Learning](https://readpaper.com/paper/3198128029)
## Abstract

## Results and Models
| Model | MAE-SBP | MAE-DBP |
| :---: | :-----: | :-----: |
| ResNet501D | - | - |


## Citation

```latex
@inproceedings{schrumpf2021assessment,
title={Assessment of deep learning based blood pressure prediction from PPG and rPPG signals},
author={Schrumpf, Fabian and Frenzel, Patrick and Aust, Christoph and Osterhoff, Georg and Fuchs, Mirco},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={3820--3830},
year={2021}
}
```
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
_base_ = ['_base_/default_runtime.py']
_base_ = ['../_base_/default_runtime.py']

model = dict(type='BPResNet1D',
data_preprocessor=dict(type='RppgDataPreprocessor'),
data_preprocessor=dict(type='DataPreprocessor'),
backbone=dict(
type='ResNet1D',
depth=50,
frozen_stages=4,
init_cfg=dict(
type='Pretrained',
checkpoint='data/resnet_ppg_nonmixed_backbone_v2.pb')),
checkpoint='data/resnet_ppg_nonmixed_backbone.pth')),
neck=dict(type='AveragePooling'),
head=dict(type='BPDenseHead', loss=dict(type='MSELoss')))

Expand All @@ -20,11 +20,11 @@
dict(type='PackInputs', input_key='rppg'),
]
train_dataloader = dict(
batch_size=16,
batch_size=32,
num_workers=2,
dataset=dict(type=dataset_type,
ann_file='data/rPPG-BP-UKL_rppg_7s.h5',
used_subjects=[7., 8., 10., 13., 14., 19., 21., 23., 24., 27., 33., 35., 36., 44., 46., 48.],
used_idx_file='data/train.txt',
data_prefix='',
test_mode=False,
pipeline=train_pipeline),
Expand All @@ -35,14 +35,24 @@
num_workers=2,
dataset=dict(type=dataset_type,
ann_file='data/rPPG-BP-UKL_rppg_7s.h5',
used_subjects=[6],
used_idx_file='data/val.txt',
data_prefix='',
test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='MAE', gt_key='gt_label', pred_key='pred_label')
test_dataloader = val_dataloader
test_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(type=dataset_type,
ann_file='data/rPPG-BP-UKL_rppg_7s.h5',
used_idx_file='data/test.txt',
data_prefix='',
test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = val_evaluator

# optimizer
Expand All @@ -65,3 +75,8 @@
dict(type='TensorboardVisBackend')
],
)

default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=True, save_best='loss', rule='less'),
)
custom_hooks = [dict(type='EarlyStoppingHook', monitor='loss', rule='less', min_delta=0.01, strict=False, check_finite=True, patience=5)]
87 changes: 87 additions & 0 deletions configs/resnet/resnet50_4xb128_bp_ppg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
_base_ = ['../_base_/default_runtime.py']

model = dict(type='BPResNet1D',
data_preprocessor=dict(type='DataPreprocessor'),
backbone=dict(
type='ResNet1D',
depth=50,
),
neck=dict(type='AveragePooling'),
head=dict(type='BPDenseHead', loss=dict(type='MSELoss')))

dataset_type = 'PpgData'
train_pipeline = [
dict(type='PackInputs', input_key='ppg'),
]
test_pipeline = [
dict(type='PackInputs', input_key='ppg'),
]
train_dataloader = dict(
batch_size=128,
num_workers=2,
dataset=dict(type=dataset_type,
ann_file='data/mimic-iii_data/train.h5',
data_prefix='',
test_mode=False,
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=2,
dataset=dict(type=dataset_type,
ann_file='data/mimic-iii_data/val.h5',
data_prefix='',
test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = [
dict(type='MAE', gt_key='gt_label',
pred_key='pred_label', compute_loss=True),
# dict(type='BlandAltmanPlot')
]
test_dataloader = dict(
batch_size=64,
num_workers=2,
dataset=dict(type=dataset_type,
ann_file='data/mimic-iii_data/test.h5',
data_prefix='',
test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = [
dict(type='MAE', gt_key='gt_label',
pred_key='pred_label')
]

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam',
lr=0.001,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0,
amsgrad=False))

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
val_cfg = dict()
test_cfg = dict()

visualizer = dict(
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend')
],
)

default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1,
by_epoch=True, save_best='MAE/pred_loss', rule='less'),
)
custom_hooks = [
dict(type='EarlyStoppingHook', monitor='MAE/pred_loss', rule='less',
min_delta=0.01, strict=False, check_finite=True, patience=5)
]
3 changes: 2 additions & 1 deletion deep_vital/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .ppg_signal_dataset import RppgData
from .ppg_signal_dataset import PpgData
from .rppg_signal_dataset import RppgData
from .transforms import *
29 changes: 15 additions & 14 deletions deep_vital/datasets/ppg_signal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ def expanduser(path):


@DATASETS.register_module()
class RppgData(BaseDataset):
class PpgData(BaseDataset):
_signal_name = 'ppg'
def __init__(self,
ann_file,
used_subjects=(),
used_idx_file=None,
metainfo=None,
data_root='',
data_prefix='',
Expand All @@ -32,9 +33,9 @@ def __init__(self,
lazy_init=False,
max_refetch=1000):
self.label = None
self.rppg = None
self.signal = None
self.subject_idx = None
self.used_subjects=used_subjects
self.used_idx_file=used_idx_file
self.num = None

if isinstance(data_prefix, str):
Expand Down Expand Up @@ -65,24 +66,24 @@ def load_data_list(self):
assert os.path.exists(self.ann_file), self.ann_file

data = h5py.File(self.ann_file, 'r')
self.label = np.array(data.get('/label')).T
self.rppg = np.array(data.get('/rppg')).T
self.subject_idx = np.array(data.get('/subject_idx'), dtype=int)[0, :]
self.label = np.array(data.get('/label'))
self.signal = np.array(data.get(f'/{self._signal_name}'))
self.subject_idx = np.array(data.get('/subject_idx'), dtype=int)
subjects_list = np.unique(self.subject_idx)

if self.used_subjects:
idx_used = np.where(np.isin(self.subject_idx, self.used_subjects))[-1]
if self.used_idx_file and os.path.exists(self.used_idx_file):
idx_used = np.loadtxt(self.used_idx_file, dtype=np.int64, delimiter=',')
self.label = self.label[idx_used]
self.rppg = self.rppg[idx_used]
self.signal = self.signal[idx_used]
self.subject_idx = self.subject_idx[idx_used]

self.num = self.subject_idx.shape[0]

data_list = []
for _label, _rppg, _subject_idx in zip(self.label, self.rppg, self.subject_idx):
for _label, _signal, _subject_idx in zip(self.label, self.signal, self.subject_idx):
# sbp_label = _label[0]
# dbp_label = _label[1]
# _label = _label.reshape((1, 2, 1))
info = {'gt_label': _label, 'rppg': _rppg, 'subject_idx': _subject_idx}
info = {'gt_label': _label, f'{self._signal_name}': _signal, 'subject_idx': _subject_idx}
data_list.append(info)
return data_list
8 changes: 8 additions & 0 deletions deep_vital/datasets/rppg_signal_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

from deep_vital.registry import DATASETS
from .ppg_signal_dataset import PpgData


@DATASETS.register_module()
class RppgData(PpgData):
_signal_name = 'rppg'
Empty file.
3 changes: 2 additions & 1 deletion deep_vital/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .mae import MAE
from .mae import MAE
from .bland_altman_plot import BlandAltmanPlot
74 changes: 74 additions & 0 deletions deep_vital/evaluation/metrics/bland_altman_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Optional
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import PIL.Image as Image
from mmengine.evaluator import BaseMetric
from mmengine.visualization import Visualizer
from mmengine import MessageHub
from deep_vital.registry import METRICS

message_hub = MessageHub.get_current_instance()


def bland_altman_plot(data1, data2, *args, **kwargs):
data1 = np.asarray(data1.numpy())
data2 = np.asarray(data2.numpy())
mean = np.mean([data1, data2], axis=0)
diff = data1 - data2 # Difference between data1 and data2
md = np.mean(diff) # Mean of the difference
sd = np.std(diff, axis=0) # Standard deviation of the difference

plt.cla()
plt.scatter(mean, diff, *args, **kwargs)
plt.axhline(md, color='gray', linestyle='-')
plt.axhline(md + 1.96*sd, color='gray', linestyle='--')
plt.axhline(md - 1.96*sd, color='gray', linestyle='--')
canvas = FigureCanvasAgg(plt.gcf())
canvas.draw()
w, h = canvas.get_width_height()
buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8)
buf.shape = (w, h, 4)
buf = np.roll(buf, 3, axis=2)
image = Image.frombytes('RGBA', (w, h), buf.tostring())
image = np.asarray(image)
rgb_image = image[:, :, :3]
plt.close()
return rgb_image


@METRICS.register_module()
class BlandAltmanPlot(BaseMetric):
metric = 'BlandAltmanPlot'
default_prefix = 'BlandAltmanPlot'

def __init__(self,
gt_key='gt_label',
pred_key='pred_label',
collect_device: str = 'cpu',
prefix: Optional[str] = None):
super().__init__(collect_device=collect_device, prefix=prefix)
self.gt_key = gt_key
self.pred_key = pred_key
self.visualizer = Visualizer.get_current_instance()

def process(self, data_batch, data_samples):
for data_sample in data_samples:
result = {
'pred_label': data_sample[self.pred_key],
'gt_label': data_sample[self.gt_key],
}
self.results.append(result)

def compute_metrics(self, results):
metrics = {}
target = torch.stack([res['gt_label'] for res in results])
pred = torch.stack([res['pred_label'] for res in results])

sbp_BAP = bland_altman_plot(pred[:, 0], target[:, 0])
dbp_BAP = bland_altman_plot(pred[:, 1], target[:, 1])
current_step = message_hub.get_info('epoch')
self.visualizer.add_image('sbp_BAP', sbp_BAP, current_step)
self.visualizer.add_image('dbp_BAP', dbp_BAP, current_step)
return metrics
Loading

0 comments on commit 7e73dce

Please sign in to comment.