We usually define a neural network in a deep learning task as a model, and this model is the core of an algorithm. MMEngine abstracts a unified model, `BaseModel`, to standardize the interfaces for training, testing and other processes. All models implemented in MMSegmentation inherit from `BaseModel`; on top of it, MMSegmentation implements `forward` and adds functions specific to semantic segmentation algorithms.
In MMSegmentation, we abstract the network architecture as a **Segmentor**: a model that contains all components of a network. We have already implemented `EncoderDecoder` and `CascadeEncoderDecoder`, which typically consist of a data preprocessor, a backbone, a decode head and an auxiliary head.

- **Data preprocessor** is the part that copies data to the target device and preprocesses the data into the model input format.
- **Backbone** is the part that transforms an image into feature maps, such as a ResNet-50 without the last fully connected layer.
- **Neck** is the part that connects the backbone and heads. It performs refinements or reconfigurations on the raw feature maps produced by the backbone, for example a Feature Pyramid Network (FPN).
- **Decode head** is the part that transforms the feature maps into a segmentation mask, such as the decode head of PSPNet.
- **Auxiliary head** is an optional component that transforms the feature maps into segmentation masks used only for computing auxiliary losses.
MMSegmentation wraps `BaseModel` and implements the `BaseSegmentor` class, which mainly provides the interfaces `forward`, `train_step`, `val_step` and `test_step`. The following sections introduce these interfaces in detail.
The `forward` method returns losses or predictions for training, validation, testing, and simple inference. The method should accept three modes: `"tensor"`, `"predict"` and `"loss"`:

- `"tensor"`: Forward the whole network and return the tensor or tuple of tensors without any post-processing, same as a common `nn.Module`.
- `"predict"`: Forward and return the predictions, which are fully processed to a list of `SegDataSample`.
- `"loss"`: Forward and return a `dict` of losses according to the given inputs and data samples.

Note: `SegDataSample` is a data structure interface of MMSegmentation, used as an interface between different components. `SegDataSample` implements the abstract data element `mmengine.structures.BaseDataElement`; please refer to the `SegDataSample` documentation and the data element documentation in MMEngine for more information.
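To make the role of a data sample concrete, here is a minimal pure-Python sketch of such a container. This is not the real `mmengine` API (which also handles devices, tensor fields, and validation); the class names prefixed with `Toy` are hypothetical, while the field names mirror MMSegmentation's `gt_sem_seg` / `pred_sem_seg` / `seg_logits` convention:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ToyPixelData:
    """Stand-in for mmengine's PixelData: pixel-level data such as a mask."""
    data: List[List[int]]  # a 2-D label map (H x W)

@dataclass
class ToySegDataSample:
    """Stand-in for SegDataSample: one image's annotations and predictions."""
    metainfo: dict = field(default_factory=dict)  # e.g. image shape, path
    gt_sem_seg: Optional[ToyPixelData] = None     # ground-truth mask
    pred_sem_seg: Optional[ToyPixelData] = None   # predicted mask
    seg_logits: Optional[ToyPixelData] = None     # pre-normalization scores

# A dataset creates samples with ground truth attached ...
sample = ToySegDataSample(metainfo={'img_shape': (2, 2)},
                          gt_sem_seg=ToyPixelData(data=[[0, 1], [1, 1]]))
# ... and the model's "predict" mode fills in the prediction fields in place.
sample.pred_sem_seg = ToyPixelData(data=[[0, 1], [0, 1]])
print(sample.metainfo['img_shape'])  # (2, 2)
```

Passing one object through the pipeline, rather than loose tensors, is what lets the evaluator later read both `gt_sem_seg` and `pred_sem_seg` from the same sample.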
Note that this method doesn't handle backpropagation or optimizer updating; those are done in the `train_step` method.
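The three-mode contract can be sketched without any deep-learning framework. The toy segmentor below (hypothetical names, not the MMSegmentation implementation) shows how a single `forward` entry point dispatches to raw output, post-processed predictions, or a loss `dict` depending on `mode`:

```python
class ToySegmentor:
    """Toy model: 'segments' a 1-D signal by thresholding at self.thr."""

    def __init__(self, thr=0.5):
        self.thr = thr

    def _run(self, inputs):
        # "Network" output: one score per position (here just the input).
        return list(inputs)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        scores = self._run(inputs)
        if mode == 'tensor':
            # Raw output without any post-processing, for custom use.
            return scores
        if mode == 'predict':
            # Post-process scores into hard labels, one per position.
            return [int(s > self.thr) for s in scores]
        if mode == 'loss':
            # Compare with ground truth and return a dict of named losses.
            mse = sum((s - g) ** 2
                      for s, g in zip(scores, data_samples)) / len(scores)
            return {'loss_mse': mse}
        raise ValueError(f'unsupported mode {mode!r}')

model = ToySegmentor()
print(model.forward([0.2, 0.9], mode='tensor'))   # [0.2, 0.9]
print(model.forward([0.2, 0.9], mode='predict'))  # [0, 1]
print(model.forward([0.2, 0.9], [0, 1], mode='loss'))
```

Note that `"loss"` mode needs ground truth (`data_samples`) while the other two modes do not, which is exactly why it is the mode used by the training loop.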
Parameters:

- inputs (`torch.Tensor`) - The input tensor with shape (N, C, ...) in general.
- data_samples (`list[SegDataSample]`) - The seg data samples. They usually include information such as `metainfo` and `gt_sem_seg`. Defaults to None.

Returns:

- `dict` or `list`:
  - If `mode == "loss"`, return a `dict` of loss tensors used for backward and logging.
  - If `mode == "predict"`, return a `list` of `SegDataSample`; the inference results will be incrementally added to the `data_samples` parameter passed to the forward method. Each `SegDataSample` contains the following keys:
    - pred_sem_seg (`PixelData`): Prediction of semantic segmentation.
    - seg_logits (`PixelData`): Predicted logits of semantic segmentation before normalization.
  - If `mode == "tensor"`, return a tensor, a tuple of tensors, or a `dict` of tensors for custom use.

We briefly describe the fields of the model's configuration in the config documentation; here we elaborate on the `model.test_cfg` field. `model.test_cfg` is used to control the forward behavior: the `forward` method in `"predict"` mode can run in two modes:
- `whole_inference`: If `cfg.model.test_cfg.mode == 'whole'`, the model will inference over full images.

  A `whole_inference` mode example config:

  ```python
  model = dict(
      type='EncoderDecoder',
      ...
      test_cfg=dict(mode='whole'))
  ```
- `slide_inference`: If `cfg.model.test_cfg.mode == 'slide'`, the model will inference by sliding window. Note: if you select the `slide` mode, `cfg.model.test_cfg.stride` and `cfg.model.test_cfg.crop_size` should also be specified.

  A `slide_inference` mode example config:

  ```python
  model = dict(
      type='EncoderDecoder',
      ...
      test_cfg=dict(mode='slide', crop_size=(256, 256), stride=(170, 170)))
  ```
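For intuition on `slide` mode, the helper below (an illustration written for this guide, not MMSegmentation code) computes the window origins a sliding-window pass visits: windows are spaced `stride` apart and the last window on each axis is clamped so it never leaves the image, which is why `stride < crop_size` yields overlapping, gap-free coverage:

```python
def slide_windows(size, crop_size, stride):
    """Return (top, left) origins of crop_size windows covering the image."""
    h, w = size
    h_crop, w_crop = crop_size
    h_stride, w_stride = stride
    # Number of grid steps needed along each axis (ceil division).
    h_grids = max(h - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w - w_crop + w_stride - 1, 0) // w_stride + 1
    origins = []
    for i in range(h_grids):
        for j in range(w_grids):
            # Clamp so the window stays fully inside the image.
            top = min(i * h_stride, h - h_crop)
            left = min(j * w_stride, w - w_crop)
            origins.append((top, left))
    return origins

# A 512x512 image with 256x256 crops and stride 170 needs a 3x3 grid.
wins = slide_windows((512, 512), (256, 256), (170, 170))
print(len(wins))          # 9
print(wins[0], wins[-1])  # (0, 0) (256, 256)
```

In the real model, each window is forwarded separately and the per-window logits are accumulated (and averaged where windows overlap) into a full-image prediction.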
The `train_step` method calls the forward interface of the `loss` mode to get the loss `dict`. The `BaseModel` class implements the default model training process, including preprocessing, model forward propagation, loss calculation, optimization, and back-propagation.

Parameters:

- data (`dict` or `tuple` or `list`) - Data sampled from the dataset. In MMSegmentation, the data dict contains two fields: `inputs` and `data_samples`.
- optim_wrapper (`OptimWrapper`) - The OptimWrapper instance used to update model parameters.

Note: `OptimWrapper` provides a common interface for updating parameters; please refer to the optimizer wrapper documentation in MMEngine for more information.

Returns:

- `Dict[str, torch.Tensor]`: A `dict` of tensors for logging.

The `val_step` method calls the forward interface of the `predict` mode and returns the prediction result, which is further passed to the `process` interface of the evaluator and the `after_val_iter` interface of the Hook.
Parameters:

- data (`dict` or `tuple` or `list`) - Data sampled from the dataset. In MMSegmentation, the data dict contains two fields: `inputs` and `data_samples`.

Returns:

- `list` - The predictions of the given data.

The `BaseModel` implements `test_step` the same as `val_step`.
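The division of labor between `forward`, `train_step` and `val_step` can be sketched in plain Python. The toy one-parameter model below, with a hand-derived gradient, stands in for `BaseModel` plus `OptimWrapper` (none of this is real MMEngine code): `train_step` runs forward in `loss` mode and then updates parameters, while `val_step` runs forward in `predict` mode and touches nothing:

```python
class ToyModel:
    """Fits y = w * x by gradient descent on a mean squared error."""

    def __init__(self):
        self.w = 0.0

    def forward(self, inputs, data_samples=None, mode='tensor'):
        outs = [self.w * x for x in inputs]
        if mode == 'predict':
            return outs
        if mode == 'loss':
            n = len(inputs)
            loss = sum((o - y) ** 2 for o, y in zip(outs, data_samples)) / n
            # d(loss)/dw, derived by hand for this toy model.
            grad = sum(2 * (o - y) * x
                       for o, y, x in zip(outs, data_samples, inputs)) / n
            return {'loss': loss, '_grad_w': grad}
        return outs

    def train_step(self, data, lr=0.1):
        # forward in "loss" mode, then the optimizer update (inlined here).
        losses = self.forward(data['inputs'], data['data_samples'], mode='loss')
        self.w -= lr * losses.pop('_grad_w')
        return losses  # dict of scalars for logging

    def val_step(self, data):
        # forward in "predict" mode; no gradients, no parameter update.
        return self.forward(data['inputs'], mode='predict')

model = ToyModel()
data = {'inputs': [1.0, 2.0], 'data_samples': [2.0, 4.0]}  # target: w = 2
for _ in range(50):
    log = model.train_step(data)
print(round(model.w, 3))  # converges to 2.0
```

The `data` dict mirrors the `inputs` / `data_samples` layout described above; in the real stack, the backward pass and the parameter update are performed by autograd and the `OptimWrapper` rather than by hand.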
The `SegDataPreProcessor` implemented by MMSegmentation inherits from the `BaseDataPreprocessor` implemented by MMEngine and provides the functions of data preprocessing and copying data to the target device.

The runner carries the model to the specified device during the construction stage, while the data is moved to the specified device by the `SegDataPreProcessor` in `train_step`, `val_step`, and `test_step`; the processed data is then passed to the model.
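As a concrete illustration of the padding part of preprocessing, the sketch below (a hypothetical helper, not the real `SegDataPreProcessor`) pads a batch of variable-sized 2-D maps to the largest size in the batch so they can be stacked, using different fill values for images (`pad_val`) and segmentation maps (`seg_pad_val`):

```python
def pad_batch(maps, pad_val=0):
    """Pad each 2-D map (list of rows) to the max H x W in the batch."""
    max_h = max(len(m) for m in maps)
    max_w = max(len(row) for m in maps for row in m)
    padded = []
    for m in maps:
        # Pad every row on the right, then append filler rows at the bottom.
        rows = [row + [pad_val] * (max_w - len(row)) for row in m]
        rows += [[pad_val] * max_w for _ in range(max_h - len(rows))]
        padded.append(rows)
    return padded

# Images are padded with pad_val; segmentation maps with seg_pad_val=255,
# a value commonly used as an "ignore" label in semantic segmentation.
imgs = [[[1, 2], [3, 4]], [[5]]]
segs = [[[0, 1], [1, 1]], [[0]]]
print(pad_batch(imgs, pad_val=0)[1])    # [[5, 0], [0, 0]]
print(pad_batch(segs, pad_val=255)[1])  # [[0, 255], [255, 255]]
```

Using a distinct `seg_pad_val` matters because padded pixels must be excluded from the loss, which is why the seg map is filled with an ignore label rather than a valid class index.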
The parameters of the `SegDataPreProcessor` constructor:
The data will be processed as follows:

- Collate and move data to the target device.
- Pad the image with the defined `pad_val`, and pad the seg map with the defined `seg_pad_val`.

The parameters of the `forward` method:
The returns of the `forward` method: