|
|
@@ -0,0 +1,73 @@ |
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
# @Brief: 训练脚本 |
|
|
|
from tensorflow.keras import optimizers, callbacks, utils, applications |
|
|
|
from core.VOCdataset import VOCDataset |
|
|
|
from nets.UNet import * |
|
|
|
from core.losses import * |
|
|
|
from core.metrics import * |
|
|
|
from core.callback import * |
|
|
|
import core.config as cfg |
|
|
|
from evaluate import evaluate |
|
|
|
import tensorflow as tf |
|
|
|
import os |
|
|
|
import cv2 as cv |
|
|
|
|
|
|
|
|
|
|
|
def train_by_fit(model, epochs, train_gen, test_gen, train_steps, test_steps): |
|
|
|
""" |
|
|
|
fit方式训练 |
|
|
|
:param model: 训练模型 |
|
|
|
:param epochs: 训练轮数 |
|
|
|
:param train_gen: 训练集生成器 |
|
|
|
:param test_gen: 测试集生成器 |
|
|
|
:param train_steps: 训练次数 |
|
|
|
:param test_steps: 测试次数 |
|
|
|
:return: None |
|
|
|
""" |
|
|
|
|
|
|
|
cbk = [ |
|
|
|
callbacks.ModelCheckpoint( |
|
|
|
'./weights/epoch={epoch:02d}_val_loss={val_loss:.04f}_miou={val_object_miou:.04f}.h5', |
|
|
|
save_weights_only=True), |
|
|
|
] |
|
|
|
|
|
|
|
learning_rate = CosineAnnealingLRScheduler(epochs, train_steps, 1e-4, 1e-6, warmth_rate=0.1) |
|
|
|
optimizer = optimizers.Adam(learning_rate) |
|
|
|
lr_info = print_lr(optimizer) |
|
|
|
|
|
|
|
model.compile(optimizer=optimizer, |
|
|
|
loss=crossentropy_with_logits, |
|
|
|
metrics=[object_accuracy, object_miou, lr_info]) |
|
|
|
|
|
|
|
model.fit(train_gen, |
|
|
|
steps_per_epoch=train_steps, |
|
|
|
validation_data=test_gen, |
|
|
|
validation_steps=test_steps, |
|
|
|
epochs=epochs, |
|
|
|
callbacks=cbk) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
|
|
|
|
|
gpus = tf.config.experimental.list_physical_devices("GPU") |
|
|
|
if gpus: |
|
|
|
for gpu in gpus: |
|
|
|
tf.config.experimental.set_memory_growth(gpu, True) |
|
|
|
|
|
|
|
if not os.path.exists("weights"): |
|
|
|
os.mkdir("weights") |
|
|
|
|
|
|
|
model = UNet(cfg.input_shape, cfg.num_classes) |
|
|
|
model.summary() |
|
|
|
|
|
|
|
train_dataset = VOCDataset(cfg.train_txt_path, batch_size=cfg.batch_size, aug=True) |
|
|
|
test_dataset = VOCDataset(cfg.val_txt_path, batch_size=cfg.batch_size) |
|
|
|
|
|
|
|
train_steps = len(train_dataset) // cfg.batch_size |
|
|
|
test_steps = len(test_dataset) // cfg.batch_size |
|
|
|
|
|
|
|
train_gen = train_dataset.tf_dataset() |
|
|
|
test_gen = test_dataset.tf_dataset() |
|
|
|
|
|
|
|
train_by_fit(model, cfg.epochs, train_gen, test_gen, train_steps, test_steps) |