Develop #2

Merged 13 commits on May 22, 2023
28 changes: 28 additions & 0 deletions .github/workflows/python-publish.yml
@@ -0,0 +1,28 @@
name: deploy

on: push

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-n-publish:
    runs-on: ubuntu-latest
    if: startsWith(github.event.ref, 'refs/tags')
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Install torch
        run: pip install torch
      - name: Install wheel
        run: pip install wheel
      - name: Build Deep Vital
        run: python setup.py sdist bdist_wheel
      - name: Publish distribution to PyPI
        run: |
          pip install twine
          twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
6 changes: 5 additions & 1 deletion README.md
@@ -1,4 +1,8 @@
-# deep_vital
+# Deep Vital

PyTorch version of [non-invasive-bp-estimation-using-deep-learning](https://github.com/Fabian-Sc85/non-invasive-bp-estimation-using-deep-learning)

+The pretrained models can be found on [Zenodo](https://zenodo.org/record/7948098).
+
+## Models
+- [ResNet](configs/resnet/README.md)
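The configs in this PR follow the MMEngine conventions (registered models, datasets, metrics and hooks), so training and testing should be drivable through the standard `Runner` API. A minimal sketch, assuming `deep_vital` is importable and that importing it registers the custom modules; the `work_dir` path is only illustrative:

```python
# Minimal training/testing sketch with MMEngine (assumed workflow, not part of this PR).
from mmengine.config import Config
from mmengine.runner import Runner

import deep_vital  # noqa: F401  -- assumed to register BPResNet1D, PpgData, MAE, etc.

cfg = Config.fromfile('configs/resnet/resnet50_4xb128_bp_ppg.py')
cfg.work_dir = 'work_dirs/resnet50_4xb128_bp_ppg'  # hypothetical output directory

runner = Runner.from_cfg(cfg)
runner.train()  # uses train_cfg / val_cfg and the hooks defined in the config
runner.test()   # evaluates with test_evaluator (MAE over SBP/DBP)
```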
23 changes: 23 additions & 0 deletions configs/resnet/README.md
@@ -0,0 +1,23 @@
# ResNet Blood Pressure Model

> [Assessment of Non-Invasive Blood Pressure Prediction from PPG and rPPG Signals Using Deep Learning](https://readpaper.com/paper/3198128029)

## Abstract

## Results and Models
| Model | MAE-SBP | MAE-DBP |
| :---: | :-----: | :-----: |
| ResNet501D | - | - |


## Citation

```latex
@inproceedings{schrumpf2021assessment,
title={Assessment of deep learning based blood pressure prediction from PPG and rPPG signals},
author={Schrumpf, Fabian and Frenzel, Patrick and Aust, Christoph and Osterhoff, Georg and Fuchs, Mirco},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={3820--3830},
year={2021}
}
```
@@ -1,14 +1,14 @@
-_base_ = ['_base_/default_runtime.py']
+_base_ = ['../_base_/default_runtime.py']

model = dict(type='BPResNet1D',
-             data_preprocessor=dict(type='RppgDataPreprocessor'),
+             data_preprocessor=dict(type='DataPreprocessor'),
             backbone=dict(
                 type='ResNet1D',
                 depth=50,
                 frozen_stages=4,
                 init_cfg=dict(
                     type='Pretrained',
-                     checkpoint='data/resnet_ppg_nonmixed_backbone_v2.pb')),
+                     checkpoint='data/resnet_ppg_nonmixed_backbone.pth')),
             neck=dict(type='AveragePooling'),
             head=dict(type='BPDenseHead', loss=dict(type='MSELoss')))

@@ -20,11 +20,11 @@
    dict(type='PackInputs', input_key='rppg'),
]
train_dataloader = dict(
-    batch_size=16,
+    batch_size=32,
    num_workers=2,
    dataset=dict(type=dataset_type,
                 ann_file='data/rPPG-BP-UKL_rppg_7s.h5',
-                 used_subjects=[7., 8., 10., 13., 14., 19., 21., 23., 24., 27., 33., 35., 36., 44., 46., 48.],
+                 used_idx_file='data/train.txt',
                 data_prefix='',
                 test_mode=False,
                 pipeline=train_pipeline),
@@ -35,14 +35,24 @@
    num_workers=2,
    dataset=dict(type=dataset_type,
                 ann_file='data/rPPG-BP-UKL_rppg_7s.h5',
-                 used_subjects=[6],
+                 used_idx_file='data/val.txt',
                 data_prefix='',
                 test_mode=True,
                 pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='MAE', gt_key='gt_label', pred_key='pred_label')
-test_dataloader = val_dataloader
+test_dataloader = dict(
+    batch_size=16,
+    num_workers=2,
+    dataset=dict(type=dataset_type,
+                 ann_file='data/rPPG-BP-UKL_rppg_7s.h5',
+                 used_idx_file='data/test.txt',
+                 data_prefix='',
+                 test_mode=True,
+                 pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
test_evaluator = val_evaluator

# optimizer
@@ -65,3 +75,8 @@
        dict(type='TensorboardVisBackend')
    ],
)
+
+default_hooks = dict(
+    checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=True, save_best='loss', rule='less'),
+)
+custom_hooks = [dict(type='EarlyStoppingHook', monitor='loss', rule='less', min_delta=0.01, strict=False, check_finite=True, patience=5)]
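The config above now points at index files (`data/train.txt`, `data/val.txt`, `data/test.txt`) instead of hard-coded `used_subjects` lists, but those files are not part of this PR. A minimal sketch of how they could be generated from the annotation HDF5, assuming the `/subject_idx` layout read by `PpgData` further below; the train/val subject lists are taken from the old config and the test list is purely illustrative:

```python
# Hypothetical helper: write one sample index per line, compatible with the
# np.loadtxt(..., dtype=np.int64, delimiter=',') call used by PpgData.
import h5py
import numpy as np

with h5py.File('data/rPPG-BP-UKL_rppg_7s.h5', 'r') as f:
    subject_idx = np.array(f.get('/subject_idx'), dtype=int).ravel()

splits = {
    'data/train.txt': [7, 8, 10, 13, 14, 19, 21, 23, 24, 27, 33, 35, 36, 44, 46, 48],
    'data/val.txt': [6],
    'data/test.txt': [1, 2, 3],  # hypothetical held-out subjects; the real split is not in the PR
}
for path, subjects in splits.items():
    idx = np.where(np.isin(subject_idx, subjects))[0]  # sample indices belonging to the split's subjects
    np.savetxt(path, idx, fmt='%d')
```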
87 changes: 87 additions & 0 deletions configs/resnet/resnet50_4xb128_bp_ppg.py
@@ -0,0 +1,87 @@
_base_ = ['../_base_/default_runtime.py']

model = dict(type='BPResNet1D',
             data_preprocessor=dict(type='DataPreprocessor'),
             backbone=dict(
                 type='ResNet1D',
                 depth=50,
             ),
             neck=dict(type='AveragePooling'),
             head=dict(type='BPDenseHead', loss=dict(type='MSELoss')))

dataset_type = 'PpgData'
train_pipeline = [
    dict(type='PackInputs', input_key='ppg'),
]
test_pipeline = [
    dict(type='PackInputs', input_key='ppg'),
]
train_dataloader = dict(
    batch_size=128,
    num_workers=2,
    dataset=dict(type=dataset_type,
                 ann_file='data/mimic-iii_data/train.h5',
                 data_prefix='',
                 test_mode=False,
                 pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
    batch_size=64,
    num_workers=2,
    dataset=dict(type=dataset_type,
                 ann_file='data/mimic-iii_data/val.h5',
                 data_prefix='',
                 test_mode=True,
                 pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = [
    dict(type='MAE', gt_key='gt_label',
         pred_key='pred_label', compute_loss=True),
    # dict(type='BlandAltmanPlot')
]
test_dataloader = dict(
    batch_size=64,
    num_workers=2,
    dataset=dict(type=dataset_type,
                 ann_file='data/mimic-iii_data/test.h5',
                 data_prefix='',
                 test_mode=True,
                 pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = [
    dict(type='MAE', gt_key='gt_label',
         pred_key='pred_label')
]

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam',
                                    lr=0.001,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=0,
                                    amsgrad=False))

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
val_cfg = dict()
test_cfg = dict()

visualizer = dict(
    type='UniversalVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
        dict(type='TensorboardVisBackend')
    ],
)

default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1,
                    by_epoch=True, save_best='MAE/pred_loss', rule='less'),
)
custom_hooks = [
    dict(type='EarlyStoppingHook', monitor='MAE/pred_loss', rule='less',
         min_delta=0.01, strict=False, check_finite=True, patience=5)
]
3 changes: 2 additions & 1 deletion deep_vital/datasets/__init__.py
@@ -1,2 +1,3 @@
-from .ppg_signal_dataset import RppgData
+from .ppg_signal_dataset import PpgData
+from .rppg_signal_dataset import RppgData
from .transforms import *
29 changes: 15 additions & 14 deletions deep_vital/datasets/ppg_signal_dataset.py
@@ -17,10 +17,11 @@ def expanduser(path):


@DATASETS.register_module()
-class RppgData(BaseDataset):
+class PpgData(BaseDataset):
+    _signal_name = 'ppg'
    def __init__(self,
                 ann_file,
-                 used_subjects=(),
+                 used_idx_file=None,
                 metainfo=None,
                 data_root='',
                 data_prefix='',
@@ -32,9 +33,9 @@ def __init__(self,
                 lazy_init=False,
                 max_refetch=1000):
        self.label = None
-        self.rppg = None
+        self.signal = None
        self.subject_idx = None
-        self.used_subjects=used_subjects
+        self.used_idx_file=used_idx_file
        self.num = None

        if isinstance(data_prefix, str):
@@ -65,24 +66,24 @@ def load_data_list(self):
        assert os.path.exists(self.ann_file), self.ann_file

        data = h5py.File(self.ann_file, 'r')
-        self.label = np.array(data.get('/label')).T
-        self.rppg = np.array(data.get('/rppg')).T
-        self.subject_idx = np.array(data.get('/subject_idx'), dtype=int)[0, :]
+        self.label = np.array(data.get('/label'))
+        self.signal = np.array(data.get(f'/{self._signal_name}'))
+        self.subject_idx = np.array(data.get('/subject_idx'), dtype=int)
        subjects_list = np.unique(self.subject_idx)

-        if self.used_subjects:
-            idx_used = np.where(np.isin(self.subject_idx, self.used_subjects))[-1]
+        if self.used_idx_file and os.path.exists(self.used_idx_file):
+            idx_used = np.loadtxt(self.used_idx_file, dtype=np.int64, delimiter=',')
            self.label = self.label[idx_used]
-            self.rppg = self.rppg[idx_used]
+            self.signal = self.signal[idx_used]
            self.subject_idx = self.subject_idx[idx_used]

        self.num = self.subject_idx.shape[0]

        data_list = []
-        for _label, _rppg, _subject_idx in zip(self.label, self.rppg, self.subject_idx):
+        for _label, _signal, _subject_idx in zip(self.label, self.signal, self.subject_idx):
            # sbp_label = _label[0]
            # dbp_label = _label[1]
            # _label = _label.reshape((1, 2, 1))
-            info = {'gt_label': _label, 'rppg': _rppg, 'subject_idx': _subject_idx}
+            info = {'gt_label': _label, f'{self._signal_name}': _signal, 'subject_idx': _subject_idx}
            data_list.append(info)
        return data_list
8 changes: 8 additions & 0 deletions deep_vital/datasets/rppg_signal_dataset.py
@@ -0,0 +1,8 @@

from deep_vital.registry import DATASETS
from .ppg_signal_dataset import PpgData


@DATASETS.register_module()
class RppgData(PpgData):
    _signal_name = 'rppg'
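Because the signal name is now a class attribute, supporting another channel only takes a small subclass. A hypothetical example (not part of this PR), assuming the annotation file would also carry a matching `/ecg` dataset:

```python
# Hypothetical dataset for an ECG channel, following the same pattern as RppgData.
from deep_vital.registry import DATASETS
from deep_vital.datasets.ppg_signal_dataset import PpgData


@DATASETS.register_module()
class EcgData(PpgData):
    _signal_name = 'ecg'  # loads '/ecg' from the annotation file and packs it under the 'ecg' key
```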
3 changes: 2 additions & 1 deletion deep_vital/evaluation/metrics/__init__.py
@@ -1 +1,2 @@
-from .mae import MAE
+from .mae import MAE
+from .bland_altman_plot import BlandAltmanPlot
74 changes: 74 additions & 0 deletions deep_vital/evaluation/metrics/bland_altman_plot.py
@@ -0,0 +1,74 @@
from typing import Optional
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import PIL.Image as Image
from mmengine.evaluator import BaseMetric
from mmengine.visualization import Visualizer
from mmengine import MessageHub
from deep_vital.registry import METRICS

message_hub = MessageHub.get_current_instance()


def bland_altman_plot(data1, data2, *args, **kwargs):
    data1 = np.asarray(data1.numpy())
    data2 = np.asarray(data2.numpy())
    mean = np.mean([data1, data2], axis=0)
    diff = data1 - data2  # Difference between data1 and data2
    md = np.mean(diff)  # Mean of the difference
    sd = np.std(diff, axis=0)  # Standard deviation of the difference

    plt.cla()
    plt.scatter(mean, diff, *args, **kwargs)
    plt.axhline(md, color='gray', linestyle='-')
    plt.axhline(md + 1.96*sd, color='gray', linestyle='--')
    plt.axhline(md - 1.96*sd, color='gray', linestyle='--')
    canvas = FigureCanvasAgg(plt.gcf())
    canvas.draw()
    w, h = canvas.get_width_height()
    buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8)
    buf.shape = (w, h, 4)
    buf = np.roll(buf, 3, axis=2)
    image = Image.frombytes('RGBA', (w, h), buf.tostring())
    image = np.asarray(image)
    rgb_image = image[:, :, :3]
    plt.close()
    return rgb_image


@METRICS.register_module()
class BlandAltmanPlot(BaseMetric):
    metric = 'BlandAltmanPlot'
    default_prefix = 'BlandAltmanPlot'

    def __init__(self,
                 gt_key='gt_label',
                 pred_key='pred_label',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.gt_key = gt_key
        self.pred_key = pred_key
        self.visualizer = Visualizer.get_current_instance()

    def process(self, data_batch, data_samples):
        for data_sample in data_samples:
            result = {
                'pred_label': data_sample[self.pred_key],
                'gt_label': data_sample[self.gt_key],
            }
            self.results.append(result)

    def compute_metrics(self, results):
        metrics = {}
        target = torch.stack([res['gt_label'] for res in results])
        pred = torch.stack([res['pred_label'] for res in results])

        sbp_BAP = bland_altman_plot(pred[:, 0], target[:, 0])
        dbp_BAP = bland_altman_plot(pred[:, 1], target[:, 1])
        current_step = message_hub.get_info('epoch')
        self.visualizer.add_image('sbp_BAP', sbp_BAP, current_step)
        self.visualizer.add_image('dbp_BAP', dbp_BAP, current_step)
        return metrics
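The metric is registered but only appears as a commented-out entry in the `val_evaluator` of the PPG config above. Enabling it alongside MAE would presumably look like this config fragment (a sketch; the key names simply follow the metric's defaults, and a visualizer backend such as Tensorboard must be configured so the plots have somewhere to go):

```python
# Config fragment (sketch): log Bland-Altman plots for SBP/DBP during validation.
val_evaluator = [
    dict(type='MAE', gt_key='gt_label', pred_key='pred_label', compute_loss=True),
    dict(type='BlandAltmanPlot', gt_key='gt_label', pred_key='pred_label'),
]
```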