- #============ Hybrid quantization of the model; the resulting byte size can be inspected from the output ==========
- # import pathlib
- # tflite_models_dir = pathlib.Path("/model/")
- # tflite_models_dir.mkdir(exist_ok=True, parents=True)
- # model = tf.keras.models.load_model('/pretrainmodel/my.h5')
- # converter = tf.lite.TFLiteConverter.from_keras_model(model)
-
- # # # Write out the un-quantized model to inspect its size in bytes
- # # resnet_tflite_file = tflite_models_dir/"a.tflite"
- # # resnet_tflite_file.write_bytes(converter.convert())
-
- # # Quantize the model, convert to tflite, and inspect the resulting byte size
- # converter.target_spec.supported_ops = [
- #     tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
- #     tf.lite.OpsSet.SELECT_TF_OPS     # enable TensorFlow ops.
- # ]
- # converter.optimizations = [tf.lite.Optimize.DEFAULT]
- # quantized_tflite_file = tflite_models_dir/"a_quantized.tflite"
- # quantized_tflite_file.write_bytes(converter.convert())
- #======================================================
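- 
- # Optional size check for the two artifacts written above (a commented sketch,
- # reusing the paths assumed in the block itself). Optimize.DEFAULT applies
- # dynamic-range quantization, which typically shrinks float32 weights roughly 4x.
- # print("float tflite : %d bytes" % (tflite_models_dir/"a.tflite").stat().st_size)
- # print("quantized    : %d bytes" % (tflite_models_dir/"a_quantized.tflite").stat().st_size)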
-
- # ==================== Inference with the quantized model; TF 2.11.0 or later is recommended =======================
- import h5py
- import numpy as np
- import pandas as pd
- import tensorflow as tf
- import matplotlib.pyplot as plt
- from sklearn.metrics import log_loss
- from sklearn.metrics import confusion_matrix
-
- def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues, png_name=None):
-     if not title:
-         if normalize:
-             title = 'Normalized confusion matrix'
-         else:
-             title = 'Confusion matrix, without normalization'
- 
-     cm = confusion_matrix(y_true, y_pred)
-     if normalize:
-         # Normalize each row (true class) so entries are per-class recall fractions.
-         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
- 
-     fig, ax = plt.subplots(figsize=(10, 10))
-     for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
-                  ax.get_xticklabels() + ax.get_yticklabels()):
-         item.set_fontsize(20)
-     im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
-     ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
-     ax.set(xticks=np.arange(cm.shape[1]),
-            yticks=np.arange(cm.shape[0]),
-            xticklabels=classes,
-            yticklabels=classes,
-            title=title,
-            ylabel='True label',
-            xlabel='Predicted label')
- 
-     plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
-              rotation_mode="anchor")
- 
-     # Annotate each cell; use white text on dark cells for readability.
-     fmt = '.2f' if normalize else 'd'
-     thresh = cm.max() / 2.
-     for i in range(cm.shape[0]):
-         for j in range(cm.shape[1]):
-             ax.text(j, i, format(cm[i, j], fmt),
-                     ha="center", va="center",
-                     color="white" if cm[i, j] > thresh else "black")
- 
-     fig.tight_layout()
-     if png_name:
-         fig.savefig(png_name)
-     return
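- 
- # Hypothetical usage example (not from the original script): with three samples and
- # two classes, this writes a row-normalized 2x2 matrix to 'cm_demo.png'.
- # plot_confusion_matrix([0, 1, 1], [0, 1, 0], classes=['a', 'b'], normalize=True, png_name='cm_demo.png')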
-
- #=========================================================================================================#
- def evaluate_model(interpreter, test_images, test_labels, num_class, is_eval=False):
-     input_index = interpreter.get_input_details()[0]["index"]
-     output_index = interpreter.get_output_details()[0]["index"]
- 
-     # Run inference one sample at a time and collect the raw class scores.
-     prediction_digits = []
-     pred_outputs = []
-     for i, test_image in enumerate(test_images, start=1):
-         if i % 100 == 0:
-             print(i)
-         # TFLite expects a leading batch dimension; add it and match the input dtype.
-         test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
-         interpreter.set_tensor(input_index, test_image)
- 
-         interpreter.invoke()
- 
-         output = interpreter.get_tensor(output_index)
-         pred_outputs.append(output[0].reshape([1, num_class]))
-         prediction_digits.append(np.argmax(output[0]))
- 
-     pred_output_all = np.vstack(pred_outputs)
- 
-     if is_eval:
-         return pred_output_all, prediction_digits
-     else:
-         accurate_count = 0
-         for index in range(len(prediction_digits)):
-             if prediction_digits[index] == test_labels[index]:
-                 accurate_count += 1
-         accuracy = accurate_count * 1.0 / len(prediction_digits)
-         return accuracy, pred_output_all, prediction_digits
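- 
- # Note: the loop above drives the interpreter clip by clip, which is required when
- # the model was converted with a fixed batch-1 input signature. If the converted
- # model supports a dynamic batch dimension, the whole set can be scored in a single
- # invoke (a commented sketch under that assumption, not part of the original):
- # interpreter.resize_tensor_input(input_index, data_val.shape)
- # interpreter.allocate_tensors()
- # interpreter.set_tensor(input_index, data_val.astype(np.float32))
- # interpreter.invoke()
- # preds = interpreter.get_tensor(output_index)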
-
- #=========================================================================================================#
- is_eval = False
- 
- if not is_eval:
-     data_path = '/dataset/h5_origin'
-     val_csv = data_path + '/fold1_evaluate.csv'
-     feat_path = data_path + '/feature_origin.h5'
-     model_path = '/model/1.tflite'
-     csv_path = '/model/picture.csv'
- 
- num_freq_bin = 128
- num_classes = 10
-
- #=========================================================================================================#
- # Load features
- def load_hdf5(hdf5_path):
-     '''Load hdf5 file.
- 
-     Returns:
-       data_dict: dict of data, e.g.:
-         {'y_train': (train_num,),
-          'y_val': (val_num,),
-          'data_train': (train_num, frames_num, mel_bins),
-          'data_val': (val_num, frames_num, mel_bins)}
-     '''
-     data_dict = {}
- 
-     with h5py.File(hdf5_path, 'r') as hf:
-         data_dict['y_train'] = hf['y_train'][:]
-         data_dict['y_val'] = hf['y_val'][:]
-         data_dict['data_train'] = hf['data_train'][:].astype(np.float32)
-         data_dict['data_val'] = hf['data_val'][:].astype(np.float32)
- 
-     return data_dict
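- 
- # The layout assumed above: four datasets at the file root. A hypothetical writer,
- # shown only to document the expected structure (variable names are placeholders):
- # with h5py.File(feat_path, 'w') as hf:
- #     hf.create_dataset('data_train', data=train_feats)   # (N_train, frames, mel_bins)
- #     hf.create_dataset('data_val', data=val_feats)       # (N_val, frames, mel_bins)
- #     hf.create_dataset('y_train', data=train_labels)     # (N_train,)
- #     hf.create_dataset('y_val', data=val_labels)         # (N_val,)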
-
- #=========================================================================================================#
- if not is_eval:
-     feature = load_hdf5(feat_path)
-     data_val = feature['data_val']
-     y_val = feature['y_val']
-     y_val_onehot = tf.keras.utils.to_categorical(y_val, num_classes)
- 
-     dev_test_df = pd.read_csv(val_csv, sep='\t', encoding='ASCII')
-     wav_paths = dev_test_df['filename'].tolist()
- 
-     # Filenames follow the pattern '<scene>-...-<device>.wav': strip the directory,
-     # then read the device id from the last '-' field and the scene from the first.
-     device_idxs = []
-     class_idxs = []
-     for idx, elem in enumerate(wav_paths):
-         wav_paths[idx] = wav_paths[idx].split('/')[-1]
-         stem = wav_paths[idx].split('.')[0]
-         device_idxs.append(stem.split('-')[-1])
-         class_idxs.append(stem.split('-')[0])
-     device_list = np.unique(device_idxs)
-     class_list = np.unique(class_idxs)
-
- #=========================================================================================================#
- interpreter_quant = tf.lite.Interpreter(model_path=model_path)
- interpreter_quant.allocate_tensors()
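- 
- # Optional sanity check: evaluate_model assumes a single float32 input tensor whose
- # shape matches one feature sample plus a leading batch dimension.
- # print(interpreter_quant.get_input_details()[0]['shape'],
- #       interpreter_quant.get_input_details()[0]['dtype'])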
-
- #=========================================================================================================#
-
- if not is_eval:
-     overall_acc, preds, preds_class_idx = evaluate_model(interpreter_quant,
-                                                          data_val,
-                                                          y_val,
-                                                          num_class=num_classes,
-                                                          is_eval=False)
- 
-     over_loss = log_loss(y_val_onehot, preds)
-     print("\n\nVal acc: ", "{0:.4f}".format(overall_acc))
-     print("Val log loss: ", "{0:.4f}".format(over_loss))
- 
-     # Per-device breakdown: accuracy and log loss restricted to each recording device.
-     device_acc = []
-     device_loss = []
-     for device_id in device_list:
-         cur_preds = np.array([preds[i] for i in range(len(device_idxs)) if device_idxs[i] == device_id])
-         cur_y_pred_val = np.argmax(cur_preds, axis=1)
-         cur_y_val_onehot = np.array([y_val_onehot[i] for i in range(len(device_idxs)) if device_idxs[i] == device_id])
-         cur_y_val = [y_val[i] for i in range(len(device_idxs)) if device_idxs[i] == device_id]
-         cur_loss = log_loss(cur_y_val_onehot, cur_preds)
-         cur_acc = np.sum(cur_y_pred_val == cur_y_val) / len(cur_preds)
- 
-         device_acc.append(cur_acc)
-         device_loss.append(cur_loss)
- 
-     print("\n\nDevices list: ", device_list)
-     print("Per-device val acc : ", np.array(device_acc))
-     print("Per-device val loss : ", np.array(device_loss))
- 
-     # Confusion matrix over all validation clips.
-     y_pred_val = np.argmax(preds, axis=1)
-     conf_matrix = confusion_matrix(y_val, y_pred_val)
-     plot_confusion_matrix(y_val, y_pred_val, class_list, normalize=True, title=None, png_name=csv_path.replace('.csv', '.png'))
-     print("\n\nConfusion matrix:")
-     print(conf_matrix)
- 
-     # Per-class breakdown, analogous to the per-device one above.
-     class_acc = []
-     class_loss = []
-     for class_id in class_list:
-         cur_preds = np.array([preds[i] for i in range(len(class_idxs)) if class_idxs[i] == class_id])
-         cur_y_pred_val = np.argmax(cur_preds, axis=1)
-         cur_y_val_onehot = np.array([y_val_onehot[i] for i in range(len(class_idxs)) if class_idxs[i] == class_id])
-         cur_y_val = [y_val[i] for i in range(len(class_idxs)) if class_idxs[i] == class_id]
-         cur_loss = log_loss(cur_y_val_onehot, cur_preds)
-         cur_acc = np.sum(cur_y_pred_val == cur_y_val) / len(cur_preds)
- 
-         class_acc.append(cur_acc)
-         class_loss.append(cur_loss)
- 
-     print("\n\nClasses list: ", class_list)
-     print("Per-class val acc : ", np.array(class_acc))
-     print("Per-class val loss : ", np.array(class_loss))
-
- #=========================================================================================================#
- scene_map_str = """
- airport 0
- bus 1
- metro 2
- metro_station 3
- park 4
- public_square 5
- shopping_mall 6
- street_pedestrian 7
- street_traffic 8
- tram 9
- """
-
- # Map predicted indices back to scene labels and write a tab-separated results file.
- scene_index_map = {}
- for line in scene_map_str.strip().split('\n'):
-     ch, index = line.split()
-     scene_index_map[int(index)] = ch
- labels = [scene_index_map[c] for c in y_pred_val]
- filename = [str(a) for a in wav_paths]
- left = {'filename': filename, 'scene_label': labels}
- left_df = pd.DataFrame(left)
- right_df = pd.DataFrame(preds, columns=['airport',
-                                         'bus',
-                                         'metro',
-                                         'metro_station',
-                                         'park',
-                                         'public_square',
-                                         'shopping_mall',
-                                         'street_pedestrian',
-                                         'street_traffic',
-                                         'tram'])
- merge = pd.concat([left_df, right_df], axis=1, sort=False)
- merge.to_csv(csv_path, sep='\t', index=False)