Solvers logging (#202)

* modify solvers to log every epoch correctly
* add an `on_epoch` flag to the logger (see the sketch after this list)
* fix a bug in `pinn.py`: rename `pts` to `samples` in `_loss_phys`
* add `optimizer_zero_grad()` to the GAROM generator training loop (also sketched below)
* update imports in `callbacks.py`
* fix tests
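
Below is a minimal, hypothetical sketch of the two behavioural changes above: per-epoch logging through Lightning's `on_epoch` flag, and zeroing the generator's gradients in a GAN/GAROM-style manual-optimization loop. The module, the helper losses `_generator_loss` and `_discriminator_loss`, and the hyperparameters are illustrative assumptions, not PINA's actual solver code.

```python
import torch
import pytorch_lightning as pl


class GaromStyleSketch(pl.LightningModule):
    """Illustrative generator/discriminator training step, not PINA's GAROM."""

    def __init__(self, generator, discriminator):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, GAN-style
        self.generator = generator
        self.discriminator = discriminator

    def training_step(self, batch, batch_idx):
        opt_gen, opt_disc = self.optimizers()

        # Generator update: clear stale gradients before backprop
        # (the kind of zero_grad call this commit adds to the generator loop).
        opt_gen.zero_grad()
        g_loss = self._generator_loss(batch)  # hypothetical helper
        self.manual_backward(g_loss)
        opt_gen.step()

        # Discriminator update.
        opt_disc.zero_grad()
        d_loss = self._discriminator_loss(batch)  # hypothetical helper
        self.manual_backward(d_loss)
        opt_disc.step()

        # Log once per epoch (epoch-level aggregation) instead of per step.
        self.log('g_loss', g_loss, on_step=False, on_epoch=True)
        self.log('d_loss', d_loss, on_step=False, on_epoch=True)

    def _generator_loss(self, batch):
        # placeholder loss for the sketch
        return self.generator(batch).pow(2).mean()

    def _discriminator_loss(self, batch):
        # placeholder loss for the sketch
        return self.discriminator(batch).pow(2).mean()

    def configure_optimizers(self):
        return (torch.optim.Adam(self.generator.parameters(), lr=1e-3),
                torch.optim.Adam(self.discriminator.parameters(), lr=1e-3))
```

With `on_step=False, on_epoch=True`, Lightning accumulates the metric over the epoch and emits one aggregated value, which appears to be what "log every epoch correctly" refers to.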

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
Author:       Dario Coscia
Date:         2023-11-08 14:10:23 +01:00
Committed by: Nicola Demo
Parent:       d654259428
Commit:       3f9305d475

5 changed files with 28 additions and 27 deletions


@@ -118,8 +118,6 @@ class LabelTensor(torch.Tensor):
         tensors = [lt.extract(labels) for lt in label_tensors]
         return LabelTensor(torch.vstack(tensors), labels)
 
-    # TODO remove try/ except thing IMPORTANT
-    # make the label None of default
     def clone(self, *args, **kwargs):
         """
         Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@ class LabelTensor(torch.Tensor):
         :return: a copy of the tensor
         :rtype: LabelTensor
         """
-        try:
-            out = LabelTensor(super().clone(*args, **kwargs), self.labels)
-        except:  # used when the tensor loses the labels; note it will create a bug! Kept for compatibility with Lightning
-            out = super().clone(*args, **kwargs)
+        # # used before merging
+        # try:
+        #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        # except:
+        #     out = super().clone(*args, **kwargs)
+        out = LabelTensor(super().clone(*args, **kwargs), self.labels)
         return out
 
     def to(self, *args, **kwargs):
@@ -298,4 +297,4 @@ class LabelTensor(torch.Tensor):
         else:
             s = 'no labels\n'
         s += super().__str__()
-        return s
\ No newline at end of file
+        return s
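
With the fallback removed, `clone` always re-attaches the labels to the copied tensor. A short usage sketch of the expected behaviour (the import path is an assumption):

```python
import torch
from pina import LabelTensor  # import path assumed

lt = LabelTensor(torch.rand(4, 2), ['x', 'y'])
copy = lt.clone()                   # no try/except fallback anymore
assert copy.labels == lt.labels     # labels are always preserved
assert torch.equal(copy, lt)
```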