-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from wang-tf:develop
Develop
- Loading branch information
Showing
26 changed files
with
773 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
name: deploy | ||
|
||
on: push | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build-n-publish: | ||
runs-on: ubuntu-latest | ||
if: startsWith(github.event.ref, 'refs/tags') | ||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python 3.7 | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: 3.7 | ||
- name: Install torch | ||
run: pip install torch | ||
- name: Install wheel | ||
run: pip install wheel | ||
- name: Build Deep Vital | ||
run: python setup.py sdist bdist_wheel | ||
- name: Publish distribution to PyPI | ||
run: | | ||
pip install twine | ||
twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
# deep_vital | ||
# Deep Vital | ||
|
||
Pytorch version of [non-invasive-bp-estimation-using-deep-learning](https://github.com/Fabian-Sc85/non-invasive-bp-estimation-using-deep-learning) | ||
|
||
pretrained model can be found from [zendo](https://zenodo.org/record/7948098) | ||
|
||
## Models | ||
- [ResNet](configs/resnet/README.md) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# ResNet Blood Pressure Model | ||
|
||
> [Assessment of Non-Invasive Blood Pressure Prediction from PPG and rPPG Signals Using Deep Learning](https://readpaper.com/paper/3198128029) | ||
## Abstract | ||
|
||
## Results and Models | ||
| Model | MAE-SBP | MAE-DBP | | ||
| :---: | :-----: | :-----: | | ||
| ResNet501D | - | - | | ||
|
||
|
||
## Citation | ||
|
||
```latex | ||
@inproceedings{schrumpf2021assessment, | ||
title={Assessment of deep learning based blood pressure prediction from PPG and rPPG signals}, | ||
author={Schrumpf, Fabian and Frenzel, Patrick and Aust, Christoph and Osterhoff, Georg and Fuchs, Mirco}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
pages={3820--3830}, | ||
year={2021} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
_base_ = ['../_base_/default_runtime.py'] | ||
|
||
model = dict(type='BPResNet1D', | ||
data_preprocessor=dict(type='DataPreprocessor'), | ||
backbone=dict( | ||
type='ResNet1D', | ||
depth=50, | ||
), | ||
neck=dict(type='AveragePooling'), | ||
head=dict(type='BPDenseHead', loss=dict(type='MSELoss'))) | ||
|
||
dataset_type = 'PpgData' | ||
train_pipeline = [ | ||
dict(type='PackInputs', input_key='ppg'), | ||
] | ||
test_pipeline = [ | ||
dict(type='PackInputs', input_key='ppg'), | ||
] | ||
train_dataloader = dict( | ||
batch_size=128, | ||
num_workers=2, | ||
dataset=dict(type=dataset_type, | ||
ann_file='data/mimic-iii_data/train.h5', | ||
data_prefix='', | ||
test_mode=False, | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
) | ||
val_dataloader = dict( | ||
batch_size=64, | ||
num_workers=2, | ||
dataset=dict(type=dataset_type, | ||
ann_file='data/mimic-iii_data/val.h5', | ||
data_prefix='', | ||
test_mode=True, | ||
pipeline=test_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
) | ||
val_evaluator = [ | ||
dict(type='MAE', gt_key='gt_label', | ||
pred_key='pred_label', compute_loss=True), | ||
# dict(type='BlandAltmanPlot') | ||
] | ||
test_dataloader = dict( | ||
batch_size=64, | ||
num_workers=2, | ||
dataset=dict(type=dataset_type, | ||
ann_file='data/mimic-iii_data/test.h5', | ||
data_prefix='', | ||
test_mode=True, | ||
pipeline=test_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
) | ||
test_evaluator = [ | ||
dict(type='MAE', gt_key='gt_label', | ||
pred_key='pred_label') | ||
] | ||
|
||
# optimizer | ||
optim_wrapper = dict(optimizer=dict(type='Adam', | ||
lr=0.001, | ||
betas=(0.9, 0.999), | ||
eps=1e-08, | ||
weight_decay=0, | ||
amsgrad=False)) | ||
|
||
# train, val, test setting | ||
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) | ||
val_cfg = dict() | ||
test_cfg = dict() | ||
|
||
visualizer = dict( | ||
type='UniversalVisualizer', | ||
vis_backends=[ | ||
dict(type='LocalVisBackend'), | ||
dict(type='TensorboardVisBackend') | ||
], | ||
) | ||
|
||
default_hooks = dict( | ||
checkpoint=dict(type='CheckpointHook', interval=1, | ||
by_epoch=True, save_best='MAE/pred_loss', rule='less'), | ||
) | ||
custom_hooks = [ | ||
dict(type='EarlyStoppingHook', monitor='MAE/pred_loss', rule='less', | ||
min_delta=0.01, strict=False, check_finite=True, patience=5) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .ppg_signal_dataset import RppgData | ||
from .ppg_signal_dataset import PpgData | ||
from .rppg_signal_dataset import RppgData | ||
from .transforms import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
from deep_vital.registry import DATASETS | ||
from .ppg_signal_dataset import PpgData | ||
|
||
|
||
@DATASETS.register_module() | ||
class RppgData(PpgData): | ||
_signal_name = 'rppg' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .mae import MAE | ||
from .mae import MAE | ||
from .bland_altman_plot import BlandAltmanPlot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Optional | ||
import torch | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from matplotlib.backends.backend_agg import FigureCanvasAgg | ||
import PIL.Image as Image | ||
from mmengine.evaluator import BaseMetric | ||
from mmengine.visualization import Visualizer | ||
from mmengine import MessageHub | ||
from deep_vital.registry import METRICS | ||
|
||
message_hub = MessageHub.get_current_instance() | ||
|
||
|
||
def bland_altman_plot(data1, data2, *args, **kwargs): | ||
data1 = np.asarray(data1.numpy()) | ||
data2 = np.asarray(data2.numpy()) | ||
mean = np.mean([data1, data2], axis=0) | ||
diff = data1 - data2 # Difference between data1 and data2 | ||
md = np.mean(diff) # Mean of the difference | ||
sd = np.std(diff, axis=0) # Standard deviation of the difference | ||
|
||
plt.cla() | ||
plt.scatter(mean, diff, *args, **kwargs) | ||
plt.axhline(md, color='gray', linestyle='-') | ||
plt.axhline(md + 1.96*sd, color='gray', linestyle='--') | ||
plt.axhline(md - 1.96*sd, color='gray', linestyle='--') | ||
canvas = FigureCanvasAgg(plt.gcf()) | ||
canvas.draw() | ||
w, h = canvas.get_width_height() | ||
buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8) | ||
buf.shape = (w, h, 4) | ||
buf = np.roll(buf, 3, axis=2) | ||
image = Image.frombytes('RGBA', (w, h), buf.tostring()) | ||
image = np.asarray(image) | ||
rgb_image = image[:, :, :3] | ||
plt.close() | ||
return rgb_image | ||
|
||
|
||
@METRICS.register_module() | ||
class BlandAltmanPlot(BaseMetric): | ||
metric = 'BlandAltmanPlot' | ||
default_prefix = 'BlandAltmanPlot' | ||
|
||
def __init__(self, | ||
gt_key='gt_label', | ||
pred_key='pred_label', | ||
collect_device: str = 'cpu', | ||
prefix: Optional[str] = None): | ||
super().__init__(collect_device=collect_device, prefix=prefix) | ||
self.gt_key = gt_key | ||
self.pred_key = pred_key | ||
self.visualizer = Visualizer.get_current_instance() | ||
|
||
def process(self, data_batch, data_samples): | ||
for data_sample in data_samples: | ||
result = { | ||
'pred_label': data_sample[self.pred_key], | ||
'gt_label': data_sample[self.gt_key], | ||
} | ||
self.results.append(result) | ||
|
||
def compute_metrics(self, results): | ||
metrics = {} | ||
target = torch.stack([res['gt_label'] for res in results]) | ||
pred = torch.stack([res['pred_label'] for res in results]) | ||
|
||
sbp_BAP = bland_altman_plot(pred[:, 0], target[:, 0]) | ||
dbp_BAP = bland_altman_plot(pred[:, 1], target[:, 1]) | ||
current_step = message_hub.get_info('epoch') | ||
self.visualizer.add_image('sbp_BAP', sbp_BAP, current_step) | ||
self.visualizer.add_image('dbp_BAP', dbp_BAP, current_step) | ||
return metrics |
Oops, something went wrong.