Source code for thelper.nn.segmentation.base
import torch
import thelper.nn
class SegmModelBase(thelper.nn.utils.Module):
    """
    Base wrapper class for specialized segmentation models.

    Subclasses are expected to override the two class attributes below:

    * ``model_cls``: callable invoked as ``model_cls(pretrained=...)`` to build the wrapped
      network. The returned object must expose a ``classifier`` container whose element at
      index 4 is the final prediction layer (presumably matching the torchvision
      segmentation heads -- TODO confirm against the concrete subclasses).
    * ``in_channels``: number of input channels of that final prediction layer.
    """

    model_cls = None  # type: type
    in_channels = None  # type: int

    def __init__(self, task, pretrained=False):
        """
        Builds the wrapped model and adapts its output layer to the given task.

        Args:
            task: the segmentation task the model must solve; forwarded to the base class
                constructor and to :meth:`set_task`.
            pretrained: forwarded as-is to ``model_cls`` when the wrapped model is built.
        """
        # note: must always forward args to base class to keep backup
        super().__init__(task, **{k: v for k, v in vars().items()
                                  if k not in ["self", "task", "__class__"]})
        self.num_classes = None
        self.model = None  # will be instantiated in set_task using model_cls
        self.pretrained = pretrained
        self.set_task(task)

    def forward(self, x):
        """Returns the output of the wrapped model for the input ``x``."""
        return self.model(x)

    def set_task(self, task):
        """
        Adapts the output layer of the wrapped model to the class count of ``task``.

        The wrapped model is built on the first call only; later calls keep the existing
        model (and its weights) and replace the final prediction layer solely when the
        number of classes changes.

        Args:
            task: a ``thelper.tasks.Segmentation`` instance providing ``class_names``.

        Raises:
            AssertionError: if ``task`` is not a ``thelper.tasks.Segmentation`` instance.
        """
        # explicit raise (instead of an ``assert`` statement) so the check survives ``python -O``
        if not isinstance(task, thelper.tasks.Segmentation):
            raise AssertionError(
                "invalid task ({} currently only supports Segmentation)".format(type(self).__name__))
        num_classes = len(task.class_names)
        if self.model is None:
            # build the model only once: re-creating it on every call would throw away the
            # current weights and, when the class count is unchanged, would leave the stock
            # prediction layer in place since the replacement below would be skipped
            self.model = self.model_cls(pretrained=self.pretrained)
        if num_classes != self.num_classes:
            # Only the last layer is reinitialized; maybe the whole classifier part should be
            head = torch.nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=num_classes,
                kernel_size=(1, 1),
                stride=(1, 1),
            )
            # keep the new layer on the same device/dtype as the rest of the model, which may
            # have been moved since it was built
            ref_param = next(self.model.parameters(), None)
            if ref_param is not None:
                head = head.to(device=ref_param.device, dtype=ref_param.dtype)
            self.model.classifier[4] = head
            self.num_classes = num_classes
        # NOTE(review): ``self.task`` is not updated here; presumably the base class handles
        # it in its constructor only -- confirm whether it should be refreshed on task change