|
- # -*- coding:utf-8 -*-
-
- from __future__ import absolute_import, division, print_function
-
- import json
- import os
- import random, string
- from datetime import datetime, timezone
-
- import cv2
- import numpy as np
- import tensorflow as tf
- import tf_slim as slim
- from sklearn.model_selection import train_test_split
-
- from libs.box_utils import show_box_in_tensor, boxes_utils
- from libs.configs import cfgs, flags
- from libs.io import image_preprocess
- from libs.networks import build_whole_network, build_whole_network_batch
-
# Restrict TensorFlow to the GPUs requested via the command-line flags.
os.environ["CUDA_VISIBLE_DEVICES"] = flags.GPU_GROUP

# Load the annotation file once at import time. `label_json` is shared by
# next_img() / image_generator() below; it is expected to contain
# "image_info" (image_id -> image_name) and "instance_info" (per-box
# annotations) — TODO confirm schema against the label file producer.
with open(flags.LABEL_FILE) as f:
    label_json = json.load(f)
img_id_list = [i["image_id"] for i in label_json["image_info"]]
img_id_list = np.array(img_id_list)

# 90/10 random split of image ids into training and held-out evaluation sets.
train_list, test_list = train_test_split(img_id_list, train_size=0.9)
-
-
def next_img(step, image_list, path=flags.DATASET_DIR):
    """Return the training sample for global step `step`.

    Skips images that have no instance annotations by advancing to the
    next step. `image_list` is shuffled in place whenever a full pass
    completes (i.e. step is a multiple of len(image_list)).

    :param step: global training step, used to index into image_list
    :param image_list: np.ndarray of image ids (shuffled in place)
    :param path: directory containing the image files
    :return: (img_id, RGB image array, np.int32 array of
              [xmin, ymin, xmax, ymax, label] rows)
    """
    # BUG FIX: the original implementation recursed with
    # next_img(step + 1, image_list), dropping `path` so a non-default
    # dataset directory silently fell back to flags.DATASET_DIR.
    # The loop below keeps `path` and also avoids unbounded recursion.
    while True:
        if step % len(image_list) == 0:
            np.random.shuffle(image_list)
        img_id = image_list[step % len(image_list)]

        box_and_label_list = []
        for inst in label_json["instance_info"]:
            if inst["image_id"] == img_id:
                x1, y1, w, h = (
                    inst["box"][0][0],
                    inst["box"][0][1],
                    inst["box"][0][2],
                    inst["box"][0][3],
                )
                # Annotations are [x, y, width, height]; convert to corners.
                box = [x1, y1, x1 + w, y1 + h]  # [xmin, ymin, xmax, ymax]
                label = 1  # single foreground class
                box_and_label_list.append(box + [label])

        box_and_label_list = np.array(box_and_label_list, dtype=np.int32)
        if box_and_label_list.shape[0] == 0:
            # No annotations for this image: try the next step's image.
            # (Also skips the now-unnecessary image read the original did.)
            step += 1
            continue

        img_name = [
            i["image_name"]
            for i in label_json["image_info"]
            if i["image_id"] == img_id
        ]
        img = cv2.imread(os.path.join(path, img_name[0]))
        # NOTE(review): cv2.imread returns None for missing/unreadable files,
        # which would make the slice below raise — confirm dataset integrity.
        return img_id, img[:, :, ::-1], box_and_label_list  # BGR -> RGB
-
-
def image_generator(image_list, path=flags.DATASET_DIR):
    """Yield (img_id, RGB image, [xmin, ymin, xmax, ymax, label] int32 array)
    for every image id in `image_list` that has at least one annotation.

    Images without instance annotations are skipped silently.
    """
    for img_id in image_list:
        annotations = []
        for inst in label_json["instance_info"]:
            if inst["image_id"] != img_id:
                continue
            x1 = inst["box"][0][0]
            y1 = inst["box"][0][1]
            w = inst["box"][0][2]
            h = inst["box"][0][3]
            # [x, y, w, h] -> corner format, single class label 1.
            annotations.append([x1, y1, x1 + w, y1 + h, 1])

        matching_names = [
            info["image_name"]
            for info in label_json["image_info"]
            if info["image_id"] == img_id
        ]
        img = cv2.imread(os.path.join(path, matching_names[0]))
        annotations = np.array(annotations, dtype=np.int32)
        if annotations.shape[0] != 0:
            yield img_id, img[:, :, ::-1], annotations  # BGR -> RGB
-
-
def preprocess_img(img_plac, gtbox_plac):
    """Resize, augment, and normalize one image for training.

    :param img_plac: [H, W, 3] uint8 image tensor, RGB order.
    :param gtbox_plac: int tensor of shape [-1, 5]:
                       [xmin, ymin, xmax, ymax, label].
    :return: (img_batch with a leading batch dim of 1,
              transformed gtboxes_and_label tensor)
    """
    img_tensor = tf.cast(img_plac, tf.float32)

    img_tensor, gtboxes_and_label = image_preprocess.short_side_resize(
        img_tensor=img_tensor,
        gtboxes_and_label=gtbox_plac,
        target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN,
        length_limitation=cfgs.IMG_MAX_LENGTH,
    )
    img_tensor, gtboxes_and_label = image_preprocess.random_flip_left_right(
        img_tensor=img_tensor, gtboxes_and_label=gtboxes_and_label
    )

    # The *_v1d ResNet backbones subtract a mean computed on [0, 1]-scaled
    # inputs; every other backbone subtracts the raw pixel mean directly.
    v1d_backbones = ["resnet152_v1d", "resnet101_v1d", "resnet50_v1d"]
    if flags.NET_NAME in v1d_backbones:
        img_tensor = img_tensor / 255 - tf.constant([[cfgs.PIXEL_MEAN_]])
    else:
        img_tensor = img_tensor - tf.constant([[cfgs.PIXEL_MEAN]])

    return tf.expand_dims(img_tensor, axis=0), gtboxes_and_label
-
-
def train():
    """Build the detection graph and run the single-image training loop.

    Configuration comes from the module-level `flags`/`cfgs`. Each step
    feeds one image (and its boxes) through placeholders, periodically
    prints losses, writes TensorBoard summaries, saves a checkpoint, and
    optionally evaluates mean IoU on the held-out `test_list`.
    """

    faster_rcnn = build_whole_network.DetectionNetwork(
        base_network_name=flags.NET_NAME,
        dataset_name=flags.DATASET_NAME,
        is_training=True,
    )

    with tf.compat.v1.name_scope("get_batch"):
        # Single-image pipeline: one [H, W, 3] uint8 image and its
        # [-1, 5] int32 boxes ([xmin, ymin, xmax, ymax, label]) per step.
        img_plac = tf.compat.v1.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        gtbox_plac = tf.compat.v1.placeholder(dtype=tf.int32, shape=[None, 5])

        img_batch, gtboxes_and_label = preprocess_img(img_plac, gtbox_plac)

    biases_regularizer = tf.compat.v1.no_regularizer
    # keras l2 computes sum(w**2) * l2; the 0.5 factor matches slim's
    # 0.5 * wd * ||w||^2 convention.
    weights_regularizer = tf.keras.regularizers.l2(0.5 * (cfgs.WEIGHT_DECAY))

    # list as many types of layers as possible, even if they are not used now
    with slim.arg_scope(
        [
            slim.conv2d,
            slim.conv2d_in_plane,
            slim.conv2d_transpose,
            slim.separable_conv2d,
            slim.fully_connected,
        ],
        weights_regularizer=weights_regularizer,
        biases_regularizer=biases_regularizer,
        biases_initializer=tf.compat.v1.constant_initializer(0.0),
    ):
        (
            final_bbox,
            final_scores,
            final_category,
            loss_dict,
        ) = faster_rcnn.build_whole_detection_network(
            input_img_batch=img_batch, gtboxes_batch=gtboxes_and_label
        )

    # ----------------------------------------------------------------------------------------------------build loss
    weight_decay_loss = tf.add_n(slim.losses.get_regularization_losses())
    cls_loss = loss_dict["cls_loss"]
    reg_loss = loss_dict["reg_loss"]
    total_loss = cls_loss + reg_loss + weight_decay_loss
    # Pairwise IoU between predicted boxes and ground-truth boxes
    # (label column stripped); used only for evaluation below.
    iou = boxes_utils.ious_calu(final_bbox, gtboxes_and_label[:, :-1])

    # ---------------------------------------------------------------------------------------------------add summary
    tf.compat.v1.summary.scalar("RETINANET_LOSS/cls_loss", cls_loss)
    tf.compat.v1.summary.scalar("RETINANET_LOSS/reg_loss", reg_loss)

    tf.compat.v1.summary.scalar("LOSS/total_loss", total_loss)
    tf.compat.v1.summary.scalar("LOSS/regular_weights", weight_decay_loss)

    # Ground-truth boxes drawn on the input image for TensorBoard.
    gtboxes_in_img = show_box_in_tensor.draw_boxes_with_categories(
        net_name=flags.NET_NAME,
        dataset_name=flags.DATASET_NAME,
        img_batch=img_batch,
        boxes=gtboxes_and_label[:, :-1],
        labels=gtboxes_and_label[:, -1],
    )
    if cfgs.ADD_BOX_IN_TENSORBOARD:
        detections_in_img = show_box_in_tensor.draw_boxes_with_categories_and_scores(
            net_name=flags.NET_NAME,
            dataset_name=flags.DATASET_NAME,
            img_batch=img_batch,
            boxes=final_bbox,
            labels=final_category,
            scores=final_scores,
        )
        tf.compat.v1.summary.image("Compare/final_detection", detections_in_img)
    tf.compat.v1.summary.image("Compare/gtboxes", gtboxes_in_img)

    # ___________________________________________________________________________________________________add summary

    global_step = slim.get_or_create_global_step()
    # Step-wise LR schedule: LR, LR/10, LR/100 at the configured decay steps.
    lr = tf.compat.v1.train.piecewise_constant(
        global_step,
        boundaries=[np.int64(cfgs.DECAY_STEP[0]), np.int64(cfgs.DECAY_STEP[1])],
        values=[flags.LR, flags.LR / 10.0, flags.LR / 100.0],
    )
    tf.compat.v1.summary.scalar("lr", lr)
    optimizer = tf.compat.v1.train.MomentumOptimizer(lr, momentum=flags.MOMENTUM)

    # ---------------------------------------------------------------------------------------------compute gradients
    gradients = faster_rcnn.get_gradients(optimizer, total_loss)

    # enlarge_gradients for bias
    if cfgs.MUTILPY_BIAS_GRADIENT:
        gradients = faster_rcnn.enlarge_gradients_for_bias(gradients)

    if cfgs.GRADIENT_CLIPPING_BY_NORM:
        with tf.compat.v1.name_scope("clip_gradients_YJR"):
            gradients = slim.learning.clip_gradient_norms(
                gradients, cfgs.GRADIENT_CLIPPING_BY_NORM
            )

    # train_op
    train_op = optimizer.apply_gradients(
        grads_and_vars=gradients, global_step=global_step
    )
    summary_op = tf.compat.v1.summary.merge_all()
    init_op = tf.group(
        tf.compat.v1.global_variables_initializer(),
        tf.compat.v1.local_variables_initializer(),
    )

    # Restore from a previous training checkpoint if present, else from
    # the pretrained backbone (behavior provided by get_restorer).
    restorer, restore_ckpt = faster_rcnn.get_restorer(
        flags.TRAINED_CKPT, flags.PRETRAINED_CKPT
    )
    saver = tf.compat.v1.train.Saver(max_to_keep=1, save_relative_paths=True)

    config = tf.compat.v1.ConfigProto()
    # Grow GPU memory on demand instead of grabbing it all up front.
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(init_op)
        if not restorer is None:
            restorer.restore(sess, restore_ckpt)
            print("restore model")

        if not os.path.exists(flags.SUMMARY_PATH):
            os.makedirs(flags.SUMMARY_PATH)
        summary_writer = tf.compat.v1.summary.FileWriter(
            flags.SUMMARY_PATH, graph=sess.graph
        )

        for step in range(flags.MAX_ITERATION):

            img_id, img, gt_info = next_img(step, train_list)

            training_time = datetime.now()

            # Three mutually-exclusive step kinds:
            #   plain step (train only), info step (also fetch/print losses),
            #   summary step (also write TensorBoard summaries).
            if step % cfgs.SHOW_TRAIN_INFO_INTE != 0 and step % cfgs.SMRY_ITER != 0:
                (
                    _,
                    global_stepnp,
                    _final_bbox,
                    _final_scores,
                    _final_category,
                ) = sess.run(
                    [
                        train_op,
                        global_step,
                        final_bbox,
                        final_scores,
                        final_category,
                    ],
                    feed_dict={img_plac: img, gtbox_plac: gt_info},
                )

            else:
                if step % cfgs.SHOW_TRAIN_INFO_INTE == 0 and step % cfgs.SMRY_ITER != 0:
                    start = datetime.now()

                    (
                        _,
                        global_stepnp,
                        reg_loss_,
                        cls_loss_,
                        total_loss_,
                        _final_bbox,
                        _final_scores,
                        _final_category,
                    ) = sess.run(
                        [
                            train_op,
                            global_step,
                            reg_loss,
                            cls_loss,
                            total_loss,
                            final_bbox,
                            final_scores,
                            final_category,
                        ],
                        feed_dict={img_plac: img, gtbox_plac: gt_info},
                    )
                    end = datetime.now()
                    print(
                        """ {}: step{} image_name:{} |\t
                        reg_loss:{} |\t cls_loss:{} |\t total_loss:{}|per_cost_time:{}s""".format(
                            training_time,
                            global_stepnp,
                            str(img_id),
                            reg_loss_,
                            cls_loss_,
                            total_loss_,
                            (end - start),
                        )
                    )

                else:
                    if step % cfgs.SMRY_ITER == 0:
                        (
                            _,
                            global_stepnp,
                            summary_str,
                            _final_bbox,
                            _final_scores,
                            _final_category,
                        ) = sess.run(
                            [
                                train_op,
                                global_step,
                                summary_op,
                                final_bbox,
                                final_scores,
                                final_category,
                            ],
                            feed_dict={img_plac: img, gtbox_plac: gt_info},
                        )
                        summary_writer.add_summary(summary_str, global_stepnp)
                        summary_writer.flush()

            # Checkpoint every SAVE_WEIGHTS_INTE steps and on the last step.
            if (step > 0 and step % cfgs.SAVE_WEIGHTS_INTE == 0) or (
                step == flags.MAX_ITERATION - 1
            ):

                if not os.path.exists(flags.TRAINED_CKPT):
                    os.makedirs(flags.TRAINED_CKPT)

                save_ckpt = os.path.join(
                    flags.TRAINED_CKPT,
                    "{}_".format(flags.DATASET_NAME)
                    + str(global_stepnp)
                    + "model.ckpt",
                )
                saver.save(sess, save_ckpt)
                print(" weights had been saved to {}".format(save_ckpt))

                # validate
                if flags.EVALUATE:
                    # Re-restore from the just-saved checkpoint, then
                    # compute mean IoU over the held-out test split.
                    restorer, restore_ckpt = faster_rcnn.get_restorer(
                        flags.TRAINED_CKPT, flags.PRETRAINED_CKPT
                    )
                    restorer.restore(sess, restore_ckpt)
                    iou_list = []
                    for img_id, img, gt_info in image_generator(test_list):
                        (_final_bbox, iou_) = sess.run(
                            [
                                final_bbox,
                                iou,
                            ],
                            feed_dict={img_plac: img, gtbox_plac: gt_info},
                        )
                        print("iou: {}".format(iou_))
                        if len(iou_) == 0:
                            # No detections: count this image as IoU 0.
                            iou_ = [[0]]
                        for val in iou_:
                            iou_list.extend(val)
                    print("iou_list: {}".format(iou_list))
                    print(
                        "{} mean-iou={}".format(
                            datetime.now(timezone.utc).isoformat(), np.mean(iou_list)
                        )
                    )
-
-
# Script entry point: run training when executed directly.
if __name__ == "__main__":

    train()
|