#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: Yue Wang
@Contact: yuewangx@mit.edu
@File: util
@Time: 4/5/19 3:47 PM

Modified by
@Author: Yu Deng
@Contact: dengy02@pcl.ac.cn
@Time: 2022/7/6 16:30
"""

import os
import sys
import errno
import os.path as osp

import numpy as np
import tensorflow as tf


def cal_loss(labels=None, logits=None, smoothing=True):
    '''Calculate cross-entropy loss; apply label smoothing if requested.'''
    labels = tf.reshape(labels, (-1,))

    if smoothing:
        eps = 0.2
        n_class = logits.shape[1]
        # PyTorch original: one_hot = torch.zeros_like(logits).scatter(1, labels.view(-1, 1), 1)
        one_hot = tf.one_hot(labels, n_class)  # dtype=tf.float32
        # Put 1 - eps on the true class and spread eps evenly over the
        # remaining n_class - 1 classes.
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)

        # PyTorch original: log_prb = F.log_softmax(logits, dim=1)
        logits = tf.convert_to_tensor(logits, dtype=tf.float32)
        log_prb = tf.nn.log_softmax(logits, axis=1)

        # PyTorch original: loss = -(one_hot * log_prb).sum(dim=1).mean()
        loss = -tf.reduce_mean(tf.reduce_sum(one_hot * log_prb, axis=1))
    else:
        labels = tf.convert_to_tensor(labels, dtype=tf.int32)
        # from_logits=True: `logits` are raw scores, not probabilities.
        # The default reduction averages the per-example losses over the batch.
        cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        loss = cce(labels, logits)

    return loss
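

# Hedged cross-check, not part of the original utilities: Keras' built-in
# CategoricalCrossentropy(label_smoothing=eps) smooths the target as
# one_hot * (1 - eps) + eps / n_class, i.e. it also leaks eps / n_class back
# onto the true class, whereas cal_loss above spreads eps only over the
# n_class - 1 wrong classes. The two losses are therefore close but not
# identical; this sketch exists only for comparison.
def cal_loss_keras(labels, logits, eps=0.2):
    n_class = logits.shape[1]
    one_hot = tf.one_hot(tf.reshape(labels, (-1,)), n_class)
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True,
                                                  label_smoothing=eps)
    return cce(one_hot, logits)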


class IOStream():
    '''Append every message to a log file and echo it to stdout.'''

    def __init__(self, path):
        self.f = open(path, 'a')

    def cprint(self, text):
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        self.f.close()
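

# Minimal usage sketch for IOStream (the path below is hypothetical):
#   io = IOStream('checkpoints/exp/run.log')
#   io.cprint('epoch 1, train loss 0.42')  # echoed to stdout, appended to the log
#   io.close()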


def mkdir_if_missing(directory):
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            # Tolerate the race where another process creates the directory
            # between the exists() check and makedirs().
            if e.errno != errno.EEXIST:
                raise
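
# Note (assumes Python 3.2+): os.makedirs(directory, exist_ok=True) gives the
# same effect in one call; the explicit errno check above is kept as in the
# original helper.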


class Logger(object):
    """
    Write console output to an external text file as well.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return self so `with Logger(...) as log:` binds the logger.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()
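
# Minimal usage sketch (the path is hypothetical), mirroring how open-reid
# redirects all console output into a run log:
#   sys.stdout = Logger(osp.join('checkpoints', 'exp', 'log.txt'))
#   print('this line goes to both the console and log.txt')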


if __name__ == "__main__":
    # Smoke test: logits of shape [4, 8], labels of shape [4].
    logits = np.array([[ 0.4826,  1.8397,  0.1550, -0.2594,  0.3050,  0.8015, -0.2805,  0.0954],
                       [ 1.3998, -1.1612,  1.7943,  0.2963,  0.0122, -1.7097,  0.0919,  1.1606],
                       [-0.9231, -0.2899,  1.2428,  0.9531, -0.3307,  1.8807,  1.2747,  2.5915],
                       [-0.3514,  0.3604, -0.3423, -1.0050, -0.3290,  0.0732,  1.8311,  0.8099]], dtype=np.float32)
    labels = np.array([7, 4, 6, 1], dtype=np.int32)
    logits = tf.convert_to_tensor(logits)
    labels = tf.convert_to_tensor(labels)

    print()
    print(logits.shape)
    print(labels.shape)
    print(logits[0])
    print()

    ## one_hot for labels [7, 4, 6, 1]:
    # [[0., 0., 0., 0., 0., 0., 0., 1.],
    #  [0., 0., 0., 0., 1., 0., 0., 0.],
    #  [0., 0., 0., 0., 0., 0., 1., 0.],
    #  [0., 1., 0., 0., 0., 0., 0., 0.]]
    loss = cal_loss(logits=logits, labels=labels, smoothing=True)
    print("loss smoothing", loss)

    loss = cal_loss(logits=logits, labels=labels, smoothing=False)
    print("loss", loss)
    print("============")

    # Distributed-input smoke test; the device list assumes two visible GPUs
    # and should be adjusted to the local machine.
    global_batch_size = 8
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8]).repeat(8).batch(global_batch_size)
    distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    print(distributed_iterator)
    print("iter out:\n", next(distributed_iterator))
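
    # Hedged sketch of how a PerReplica batch is normally consumed: run a
    # per-replica computation with strategy.run and merge the results with
    # strategy.reduce (the lambda below is illustrative, not from the original).
    batch = next(distributed_iterator)
    per_replica = strategy.run(lambda x: tf.reduce_sum(x), args=(batch,))
    total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
    print("summed across replicas:", total)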

    # Raw [8, 40] logits used for the framework argmax cross-check below.
    logits_np = np.array(
        [[ 2.0835e+00,  2.8684e-01, -3.8064e-01, -6.6219e-03, -4.9030e-02,
          -7.0756e-01, -1.6587e-01,  3.7841e-02,  8.4555e-01, -2.8922e-01,
          -7.8332e-01,  1.3700e+00,  1.6362e-01, -5.7401e-01,  1.6290e+00,
           4.1743e-01, -5.2905e-01,  2.2805e-01, -2.6672e+00, -1.6487e-01,
           2.8328e-01, -3.1602e-01,  1.8948e+00,  8.3975e-01,  3.1507e-01,
           1.8589e+00, -5.6998e-01, -9.6480e-01, -5.6747e-01, -3.3958e-01,
          -1.6440e+00, -3.5754e-01, -7.9882e-01,  2.3381e-01,  6.8907e-02,
          -4.4294e-01,  3.3727e-02, -3.4425e-01, -4.4010e-01, -5.5271e-01],
         [-4.0442e-01, -3.7830e-01,  7.4986e-01, -1.9608e-01, -5.2080e-01,
          -2.9841e-01, -1.6777e+00, -1.8840e-01, -8.2758e-01, -5.0043e-01,
           6.0243e-01, -4.9775e-01, -3.9553e-01,  2.3184e-01,  1.5759e+00,
          -1.3985e+00,  4.8364e-01,  7.5944e-02, -9.1163e-01, -1.0120e+00,
          -1.0840e+00, -1.1333e+00,  7.3563e-01,  7.2417e-01, -3.8258e-01,
           2.1450e+00, -2.9394e-01,  1.6343e+00,  1.5634e-01, -7.5331e-01,
           1.8599e+00,  2.9857e-01,  2.1150e-01,  4.3066e-01, -6.0429e-01,
           9.0129e-01,  2.6864e-01,  2.7161e+00, -1.0155e+00, -8.1855e-01],
         [-2.7429e+00, -2.9181e+00,  3.0671e+00,  1.3629e+00, -1.5668e+00,
          -8.3793e-01, -3.7353e+00,  4.0808e-01, -1.5371e+00, -1.9909e+00,
           3.9708e-01, -1.5419e+00, -1.9746e+00, -1.3238e-01, -1.5096e+00,
           4.9266e-01, -1.7744e-01,  1.1247e-01,  6.9266e-01, -8.2123e-01,
          -3.9770e+00, -2.1140e+00,  9.9297e-01, -1.1321e+00, -8.7195e-01,
           5.5473e+00,  1.0612e+00,  2.4748e+00,  3.0698e-01, -2.0269e+00,
           1.7185e+00,  5.8352e+00,  3.6988e+00, -9.0987e-01,  8.1961e-01,
          -5.3954e-01, -1.4290e+00,  8.5451e+00, -2.6355e+00, -1.2917e+00],
         [ 1.1162e+01,  6.6738e-02, -1.8976e+00, -4.3443e+00, -1.9987e+00,
          -1.7699e+00,  4.5539e+00,  5.1470e+00,  3.7383e+00, -8.5479e-02,
          -3.3659e+00,  6.0107e+00,  2.9951e-02, -1.3219e+00,  1.0165e+00,
           1.7946e+00, -1.4292e+00, -1.0467e+00, -3.3188e+00, -2.4000e+00,
           4.3941e+00,  9.8264e-02, -1.6303e+00, -2.3844e+00, -6.7374e-01,
          -4.3825e+00,  1.5664e+00, -9.1431e-01, -2.6670e+00, -1.7077e+00,
           1.3887e+00, -2.2931e+00, -2.9852e+00, -3.0248e+00, -2.5534e+00,
           1.0055e+00, -1.4695e+00, -2.5426e+00, -9.7849e-01, -1.9550e+00],
         [ 3.5921e+00,  8.4306e-01, -1.3569e+00, -2.9547e+00,  4.2032e-01,
          -1.3029e+00,  1.9426e+00,  1.3834e+00,  1.7537e+00,  9.6460e-03,
          -1.5346e+00,  2.0150e+00,  2.5273e-01, -5.5676e-01,  1.9191e+00,
          -1.0179e-01,  2.7503e-01, -1.4862e+00, -4.0593e+00, -1.3711e+00,
           1.6630e+00,  3.8593e-01,  1.3875e+00,  3.4905e-01,  5.8096e-01,
          -1.4988e+00,  1.0648e+00, -1.1916e+00, -4.3314e-01, -1.2875e+00,
           8.7663e-01, -7.4405e-01, -1.5142e+00, -2.8483e-02, -2.9628e-01,
           2.5531e+00,  6.4484e-01, -2.0996e+00, -5.0988e-01, -8.8208e-01],
         [-9.0217e-01, -2.3921e-01,  7.0113e-01, -6.9945e-01, -7.7149e-01,
          -1.5625e+00, -1.4836e+00,  2.3214e-01,  2.5339e-01, -1.5055e+00,
           1.1614e+00, -9.1424e-01, -6.2381e-01,  4.1322e-02, -2.8845e-01,
          -1.1310e-01,  2.6380e+00, -9.1250e-02,  1.5196e+00, -8.0327e-01,
          -7.2858e-01, -2.9471e+00,  1.9594e+00,  1.6616e+00, -2.4765e-01,
           3.1741e+00, -6.2680e-01,  4.8777e-01,  2.8064e-01, -1.4813e+00,
           1.0913e+00,  7.4820e-02,  2.6674e-01,  4.9987e-01, -1.2711e+00,
           2.7519e+00,  2.3401e+00,  3.1536e+00, -1.6224e+00, -8.7516e-01],
         [-1.6432e+00, -6.1273e-01,  7.9820e-01,  1.0426e+00, -1.2958e+00,
          -4.0188e-01, -1.3746e+00, -8.7540e-01, -7.6528e-01, -1.2126e+00,
           1.3798e+00, -3.9156e-02, -1.2898e+00,  6.4550e-01,  6.8631e-01,
          -1.5323e+00,  5.7061e-01,  9.7129e-01, -2.0090e+00, -8.3243e-01,
          -1.9216e+00, -9.5328e-01,  1.8287e+00, -7.1166e-01, -7.5810e-01,
           2.4084e+00, -4.9811e-02,  4.6670e-01, -1.6608e-01, -2.6434e-01,
           2.1043e+00,  2.0021e+00,  7.8653e-01,  8.5328e-01, -1.0869e-01,
          -7.9721e-01,  2.4434e+00,  6.0053e+00, -1.5854e+00, -4.2325e-01],
         [ 1.3624e+00,  1.2876e+00, -1.1838e+00, -1.1431e+00,  1.5006e-01,
           4.2363e-01,  2.3242e+00,  1.2580e+00,  1.9115e+00,  6.0025e-01,
          -1.3256e+00,  3.6321e+00, -2.3834e-01, -1.7159e+00,  7.0819e-01,
           8.8843e-01, -4.8113e-02, -2.3913e+00, -9.3553e-01, -1.6055e+00,
           2.3139e+00, -5.0882e-01,  5.5813e-01, -6.4976e-01,  5.2835e-01,
          -2.4074e+00,  1.4382e+00, -1.7691e+00, -1.6073e+00, -7.5976e-01,
           4.4626e-01, -2.1728e+00, -1.8500e+00, -3.3810e-01,  1.0280e+00,
           8.1578e-01,  3.8114e-01, -9.4520e-01, -2.5134e-01, -5.3149e-01]])

    # Cross-check argmax between frameworks (assumes PyTorch is installed;
    # torch is imported here so the module itself stays TensorFlow-only).
    import torch
    logits_torch = torch.from_numpy(logits_np)
    preds_torch = logits_torch.max(dim=1)[1]
    print("torch preds", preds_torch)

    logits_tf = tf.convert_to_tensor(logits_np)
    preds_tf = tf.argmax(logits_tf, axis=1)
    print("tf preds", preds_tf)

    # Commented-out timestamp scratch (translated), kept for reference:
    # import datetime
    # print(datetime.datetime(year=2022, month=7, day=15))
    #
    # import time
    # print(time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime()))                # UTC time
    # print(time.localtime(time.time()))                                      # print the local time (struct_time)
    # print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))  # print the local time in the given format