#1 upload 'test'

Merged
szw2023 merged 1 commits from szw2023-patch-1 into master 1 year ago
  1. +73
    -0
      test/train.py

+ 73
- 0
test/train.py View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# @Brief: 训练脚本
from tensorflow.keras import optimizers, callbacks, utils, applications
from core.VOCdataset import VOCDataset
from nets.UNet import *
from core.losses import *
from core.metrics import *
from core.callback import *
import core.config as cfg
from evaluate import evaluate
import tensorflow as tf
import os
import cv2 as cv


def train_by_fit(model, epochs, train_gen, test_gen, train_steps, test_steps):
"""
fit方式训练
:param model: 训练模型
:param epochs: 训练轮数
:param train_gen: 训练集生成器
:param test_gen: 测试集生成器
:param train_steps: 训练次数
:param test_steps: 测试次数
:return: None
"""

cbk = [
callbacks.ModelCheckpoint(
'./weights/epoch={epoch:02d}_val_loss={val_loss:.04f}_miou={val_object_miou:.04f}.h5',
save_weights_only=True),
]

learning_rate = CosineAnnealingLRScheduler(epochs, train_steps, 1e-4, 1e-6, warmth_rate=0.1)
optimizer = optimizers.Adam(learning_rate)
lr_info = print_lr(optimizer)

model.compile(optimizer=optimizer,
loss=crossentropy_with_logits,
metrics=[object_accuracy, object_miou, lr_info])

model.fit(train_gen,
steps_per_epoch=train_steps,
validation_data=test_gen,
validation_steps=test_steps,
epochs=epochs,
callbacks=cbk)


if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)

if not os.path.exists("weights"):
os.mkdir("weights")

model = UNet(cfg.input_shape, cfg.num_classes)
model.summary()

train_dataset = VOCDataset(cfg.train_txt_path, batch_size=cfg.batch_size, aug=True)
test_dataset = VOCDataset(cfg.val_txt_path, batch_size=cfg.batch_size)

train_steps = len(train_dataset) // cfg.batch_size
test_steps = len(test_dataset) // cfg.batch_size

train_gen = train_dataset.tf_dataset()
test_gen = test_dataset.tf_dataset()

train_by_fit(model, cfg.epochs, train_gen, test_gen, train_steps, test_steps)

Loading…
Cancel
Save