Trainer Tutorial (#271)

* adding trainer tutorial
* implementing deepcopy for AbstractProblem and LabelTensor to match Lightning Callbacks (see the callback sketch after the commit metadata below)

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Author: Dario Coscia
Date: 2024-04-02 10:22:30 +02:00
Committed by: GitHub
Parent: 1d1d767317
Commit: e0aeb923f3
9 changed files with 1729 additions and 1 deletions
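Why the `__deepcopy__` hooks below are needed: Lightning callbacks that snapshot training state do so via `copy.deepcopy`, and whatever they copy can reach the problem and its `LabelTensor` sample points, so both classes must survive a deep copy. A minimal sketch of that usage pattern (the `SnapshotCallback` class and its logic are illustrative, not PINA code; only `pytorch_lightning.Callback` and its `on_train_epoch_end` hook are existing Lightning API):

```python
import copy

from pytorch_lightning import Callback


class SnapshotCallback(Callback):
    """Hypothetical callback that keeps a deep copy of the module each epoch."""

    def __init__(self):
        self.snapshot = None

    def on_train_epoch_end(self, trainer, pl_module):
        # This only works if everything reachable from pl_module -- including
        # a PINA problem and its LabelTensor sample points -- is deep-copyable.
        self.snapshot = copy.deepcopy(pl_module)
```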

File: pina/label_tensor.py

@@ -1,6 +1,6 @@
""" Module for LabelTensor """
from typing import Any
from copy import deepcopy
import torch
from torch import Tensor
@@ -79,6 +79,21 @@ class LabelTensor(torch.Tensor):
            )
        self._labels = labels

    def __deepcopy__(self, __):
        """
        Implements deepcopy for label tensor. It stores the current labels
        and uses the :meth:`~torch._tensor.Tensor.__deepcopy__` method to
        create a new :class:`pina.label_tensor.LabelTensor`.

        :param __: Placeholder parameter.
        :type __: None
        :return: The deep copy of the :class:`pina.label_tensor.LabelTensor`.
        :rtype: LabelTensor
        """
        labels = self.labels
        copy_tensor = deepcopy(self.tensor)
        return LabelTensor(copy_tensor, labels)

    @property
    def labels(self):
        """Property decorator for labels

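With this in place, `copy.deepcopy` on a `LabelTensor` preserves both the data and the labels. A minimal usage sketch, assuming `LabelTensor` is importable from the top-level `pina` package and constructed as `LabelTensor(tensor, labels)` as in the diff above:

```python
import copy

import torch
from pina import LabelTensor

# 10 sample points carrying the coordinate labels 'x' and 'y'
pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])

# __deepcopy__ copies the underlying tensor and reattaches the labels
pts_copy = copy.deepcopy(pts)

assert pts_copy is not pts
assert pts_copy.labels == pts.labels
assert torch.equal(pts_copy, pts)
```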
File: pina/problem/abstract_problem.py

@@ -2,6 +2,7 @@
from abc import ABCMeta, abstractmethod
from ..utils import merge_tensors, check_consistency
from copy import deepcopy
import torch
@@ -29,6 +30,23 @@ class AbstractProblem(metaclass=ABCMeta):
        # put in self.input_pts all the points that we don't need to sample
        self._span_condition_points()

    def __deepcopy__(self, memo):
        """
        Implements deepcopy for the
        :class:`~pina.problem.abstract_problem.AbstractProblem` class.

        :param dict memo: Memoization dictionary, used to avoid copying the
            same object more than once.
        :return: The deep copy of the
            :class:`~pina.problem.abstract_problem.AbstractProblem` class.
        :rtype: AbstractProblem
        """
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    @property
    def input_variables(self):
        """