From 18e5d235bc43b006b0a51d2f1621371d6aec57c4 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Mon, 28 Nov 2022 10:21:16 +0100 Subject: [PATCH] create utils.py, method for parameters count --- pina/utils.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 pina/utils.py diff --git a/pina/utils.py b/pina/utils.py new file mode 100644 index 0000000..769ee97 --- /dev/null +++ b/pina/utils.py @@ -0,0 +1,27 @@ +"""Utils module""" + +def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check + """ + Return the number of parameters of a given `model`. + + :param torch.nn.Module model: the torch module to inspect. + :param bool aggregate: if True the return values is an integer corresponding + to the total amount of parameters of whole model. If False, it returns a + dictionary whose keys are the names of layers and the values the + corresponding number of parameters. Default is True. + :param bool trainable: if True, only trainable parameters are count, + otherwise no. Default is True. + :return: the number of parameters of the model + :rtype: dict or int + """ + tmp = {} + for name, parameter in model.named_parameters(): + if only_trainable and not parameter.requires_grad: + continue + + tmp[name] = parameter.numel() + + if aggregate: + tmp = sum(tmp.values()) + + return tmp