|
- """
- A modied version of the code by Tae Hwan Jung
- https://github.com/graykode/modelsummary
- """
-
- import torch
- import numpy as np
- import torch.nn as nn
- from collections import OrderedDict
-
- def summary(model, input_shape, batch_size=-1, intputshow=True):
-
- def register_hook(module):
- def hook(module, input, output=None):
- class_name = str(module.__class__).split(".")[-1].split("'")[0]
- module_idx = len(summary)
-
- m_key = "%s-%i" % (class_name, module_idx + 1)
- summary[m_key] = OrderedDict()
- summary[m_key]["input_shape"] = list(input[0].size())
- summary[m_key]["input_shape"][0] = batch_size
-
- params = 0
- if hasattr(module, "weight") and hasattr(module.weight, "size"):
- params += torch.prod(torch.LongTensor(list(module.weight.size())))
- summary[m_key]["trainable"] = module.weight.requires_grad
- if hasattr(module, "bias") and hasattr(module.bias, "size"):
- params += torch.prod(torch.LongTensor(list(module.bias.size())))
- summary[m_key]["nb_params"] = params
-
- if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)
- and not (module == model)) and 'torch' in str(module.__class__):
- if intputshow is True:
- hooks.append(module.register_forward_pre_hook(hook))
- else:
- hooks.append(module.register_forward_hook(hook))
-
- # create properties
- summary = OrderedDict()
- hooks = []
-
- # register hook
- model.apply(register_hook)
- model(torch.zeros(input_shape))
-
- # remove these hooks
- for h in hooks:
- h.remove()
-
- model_info = ''
-
- model_info += "-----------------------------------------------------------------------\n"
- line_new = "{:>25} {:>25} {:>15}".format("Layer (type)", "Input Shape", "Param #")
- model_info += line_new + '\n'
- model_info += "=======================================================================\n"
-
- total_params = 0
- total_output = 0
- trainable_params = 0
- for layer in summary:
- line_new = "{:>25} {:>25} {:>15}".format(
- layer,
- str(summary[layer]["input_shape"]),
- "{0:,}".format(summary[layer]["nb_params"]),
- )
-
- total_params += summary[layer]["nb_params"]
- if intputshow is True:
- total_output += np.prod(summary[layer]["input_shape"])
- else:
- total_output += np.prod(summary[layer]["output_shape"])
- if "trainable" in summary[layer]:
- if summary[layer]["trainable"] == True:
- trainable_params += summary[layer]["nb_params"]
-
- model_info += line_new + '\n'
-
- model_info += "=======================================================================\n"
- model_info += "Total params: {0:,}\n".format(total_params)
- model_info += "Trainable params: {0:,}\n".format(trainable_params)
- model_info += "Non-trainable params: {0:,}\n".format(total_params - trainable_params)
- model_info += "-----------------------------------------------------------------------\n"
-
- return model_info
|