Solvers logging (#202)
* Modifying solvers to log every epoch correctly * add `on_epoch` flag to logger * fix bug in `pinn.py` `pts -> samples` in `_loss_phys` * add `optimizer_zero_grad()` in garom generator training loop * modify imports in `callbacks.py` * fixing tests --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
d654259428
commit
3f9305d475
@@ -118,8 +118,6 @@ class LabelTensor(torch.Tensor):
|
||||
tensors = [lt.extract(labels) for lt in label_tensors]
|
||||
return LabelTensor(torch.vstack(tensors), labels)
|
||||
|
||||
# TODO remove try/ except thing IMPORTANT
|
||||
# make the label None of default
|
||||
def clone(self, *args, **kwargs):
|
||||
"""
|
||||
Clone the LabelTensor. For more details, see
|
||||
@@ -128,11 +126,12 @@ class LabelTensor(torch.Tensor):
|
||||
:return: a copy of the tensor
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
try:
|
||||
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
|
||||
out = super().clone(*args, **kwargs)
|
||||
|
||||
# # used before merging
|
||||
# try:
|
||||
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
# except:
|
||||
# out = super().clone(*args, **kwargs)
|
||||
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
return out
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
@@ -298,4 +297,4 @@ class LabelTensor(torch.Tensor):
|
||||
else:
|
||||
s = 'no labels\n'
|
||||
s += super().__str__()
|
||||
return s
|
||||
return s
|
||||
Reference in New Issue
Block a user