Trainer Tutorial (#271)
* adding tutorial trainer * implementing deepcopy for AbstractProblem and LabelTensor to match Lightning Callbacks --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
""" Module for LabelTensor """
|
||||
|
||||
from typing import Any
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
@@ -79,6 +79,21 @@ class LabelTensor(torch.Tensor):
|
||||
)
|
||||
self._labels = labels
|
||||
|
||||
def __deepcopy__(self, __):
|
||||
"""
|
||||
Implements deepcopy for label tensor. By default it stores the
|
||||
current labels and use the :meth:`~torch._tensor.Tensor.__deepcopy__`
|
||||
method for creating a new :class:`pina.label_tensor.LabelTensor`.
|
||||
|
||||
:param __: Placeholder parameter.
|
||||
:type __: None
|
||||
:return: The deep copy of the :class:`pina.label_tensor.LabelTensor`.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
labels = self.labels
|
||||
copy_tensor = deepcopy(self.tensor)
|
||||
return LabelTensor(copy_tensor, labels)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""Property decorator for labels
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..utils import merge_tensors, check_consistency
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
@@ -29,6 +30,23 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
Implements deepcopy for the
|
||||
:class:`~pina.problem.abstract_problem.AbstractProblem` class.
|
||||
|
||||
:param dict memo: Memory dictionary, to avoid excess copy
|
||||
:return: The deep copy of the
|
||||
:class:`~pina.problem.abstract_problem.AbstractProblem` class
|
||||
:rtype: AbstractProblem
|
||||
"""
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls)
|
||||
memo[id(self)] = result
|
||||
for k, v in self.__dict__.items():
|
||||
setattr(result, k, deepcopy(v, memo))
|
||||
return result
|
||||
|
||||
@property
|
||||
def input_variables(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user