forked from Segment-Something/segm-models-public
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
23 lines (21 loc) · 739 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import segmentation_models_pytorch as smp
def get_model(model_name: str, encoder_name: str, encoder_weights: str, activation: str,
              classes=1, in_channels=3):
    """Build a segmentation_models_pytorch model selected by architecture name.

    Args:
        model_name: Architecture key; supported values are ``'unet'``
            (``smp.Unet``) and ``'deeplabv3+'`` (``smp.DeepLabV3Plus``).
        encoder_name: Passed through unchanged as ``encoder_name`` to the
            smp constructor.
        encoder_weights: Passed through unchanged as ``encoder_weights``.
        activation: Passed through unchanged as ``activation``.
        classes: Number of output classes (default 1).
        in_channels: Number of input channels (default 3).

    Returns:
        The constructed smp model instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
            Previously an unknown name fell through both ``if`` branches and
            failed with an ``UnboundLocalError`` on ``return model``.
    """
    # Each model_name must map to exactly one architecture. Resolve the
    # class first so the constructor arguments are written only once.
    if model_name == 'unet':
        model_cls = smp.Unet
    elif model_name == 'deeplabv3+':
        model_cls = smp.DeepLabV3Plus
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; "
            "expected one of 'unet', 'deeplabv3+'"
        )

    return model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
    )