supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -15,11 +15,8 @@ from .label_tensor import LabelTensor
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition import Condition
from .dataset import SamplePointDataset
from .dataset import SamplePointLoader
from .optimizer import TorchOptimizer
from .scheduler import TorchScheduler
from .condition.condition import Condition
from .data.dataset import SamplePointDataset
from .data.dataset import SamplePointLoader
from .data import SamplePointDataset
from .data import SamplePointLoader

View File

@@ -1,10 +1,10 @@
__all__ = [
'Condition',
'ConditionInterface',
'InputOutputCondition',
'InputEquationCondition',
'LocationEquationCondition',
'DomainOutputCondition',
'DomainEquationCondition'
]
from .condition_interface import ConditionInterface
from .input_output_condition import InputOutputCondition
from .domain_output_condition import DomainOutputCondition
from .domain_equation_condition import DomainEquationCondition

View File

@@ -1,9 +1,11 @@
""" Condition module. """
from ..label_tensor import LabelTensor
from ..geometry import Location
from ..domain import DomainInterface
from ..equation.equation import Equation
from . import DomainOutputCondition, DomainEquationCondition
def dummy(a):
"""Dummy function for testing purposes."""
@@ -51,14 +53,6 @@ class Condition:
"""
__slots__ = [
"input_points",
"output_points",
"location",
"equation",
"data_weight",
]
# def _dictvalue_isinstance(self, dict_, key_, class_):
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
# if key_ not in dict_.keys():
@@ -77,11 +71,17 @@ class Condition:
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
# )
from . import InputOutputCondition
def __new__(cls, *args, **kwargs):
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
return InputOutputCondition(**kwargs)
return DomainOutputCondition(
domain=kwargs["input_points"],
output_points=kwargs["output_points"]
)
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
return DomainOutputCondition(**kwargs)
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
return DomainEquationCondition(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

View File

@@ -4,6 +4,9 @@ from abc import ABCMeta, abstractmethod
class ConditionInterface(metaclass=ABCMeta):
def __init__(self) -> None:
self._problem = None
@abstractmethod
def residual(self, model):
"""

View File

@@ -1,26 +1,34 @@
from . import ConditionInterface
class InputOutputCondition(ConditionInterface):
class DomainOutputCondition(ConditionInterface):
"""
Condition for domain/output data.
"""
__slots__ = ["input_points", "output_points"]
__slots__ = ["domain", "output_points"]
def __init__(self, input_points, output_points):
def __init__(self, domain, output_points):
"""
Constructor for the `DomainOutputCondition` class.
"""
super().__init__()
self.input_points = input_points
print(self)
self.domain = domain
self.output_points = output_points
@property
def input_points(self):
"""
Get the input points of the condition.
"""
return self._problem.domains[self.domain]
def residual(self, model):
"""
Compute the residual of the condition.
"""
return self.batch_residual(model, self.input_points, self.output_points)
return self.batch_residual(model, self.domain, self.output_points)
@staticmethod
def batch_residual(model, input_points, output_points):
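The renamed class now stores only a domain key and resolves the concrete points through its owning problem. A self-contained stub (all names hypothetical) of the lookup the new input_points property performs:

class _ProblemStub:
    # maps domain names to their discretised points
    domains = {"interior": [[0.1], [0.5], [0.9]]}

class _ConditionStub:
    def __init__(self, domain):
        self.domain = domain
        self._problem = _ProblemStub()

    @property
    def input_points(self):
        # same indirection as DomainOutputCondition.input_points
        return self._problem.domains[self.domain]

print(_ConditionStub("interior").input_points)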

View File

@@ -1,7 +1,7 @@
from . import ConditionInterface
class InputOutputCondition(ConditionInterface):
class InputEquationCondition(ConditionInterface):
"""
Condition for input/equation data.
"""

View File

@@ -0,0 +1,7 @@
__all__ = [
]
from .pina_dataloader import SamplePointLoader
from .data_dataset import DataPointDataset
from .sample_dataset import SamplePointDataset
from .pina_batch import Batch

pina/data/data_dataset.py Normal file
View File

@@ -0,0 +1,41 @@
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
class DataPointDataset(Dataset):
def __init__(self, problem, device) -> None:
super().__init__()
input_list = []
output_list = []
self.condition_names = []
for name, condition in problem.conditions.items():
if hasattr(condition, "output_points"):
input_list.append(problem.conditions[name].input_points)
output_list.append(problem.conditions[name].output_points)
self.condition_names.append(name)
self.input_pts = LabelTensor.stack(input_list)
self.output_pts = LabelTensor.stack(output_list)
if self.input_pts != []:
self.condition_indeces = torch.cat(
[
torch.tensor([i] * len(input_list[i]))
for i in range(len(self.condition_names))
],
dim=0,
)
else: # if there are no data points
self.condition_indeces = torch.tensor([])
self.input_pts = torch.tensor([])
self.output_pts = torch.tensor([])
self.input_pts = self.input_pts.to(device)
self.output_pts = self.output_pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.input_pts.shape[0]
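The index bookkeeping above can be exercised in isolation; a standalone sketch with hypothetical condition names:

import torch

input_list = [torch.rand(3, 2), torch.rand(5, 2)]  # points of two conditions
condition_names = ["data_a", "data_b"]

# one integer per row, indexing into condition_names
condition_indeces = torch.cat(
    [torch.tensor([i] * len(input_list[i])) for i in range(len(condition_names))],
    dim=0,
)
print(condition_indeces)  # tensor([0, 0, 0, 1, 1, 1, 1, 1])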

pina/data/pina_batch.py Normal file
View File

@@ -0,0 +1,36 @@
class Batch:
"""
This class represents a single batch of sample or data points.
"""
def __init__(self, type_, idx, *args, **kwargs) -> None:
"""
"""
if type_ == "sample":
if len(args) != 2:
raise RuntimeError
input = args[0]
conditions = args[1]
self.input = input[idx]
self.condition = conditions[idx]
elif type_ == "data":
if len(args) != 3:
raise RuntimeError
input = args[0]
output = args[1]
conditions = args[2]
self.input = input[idx]
self.output = output[idx]
self.condition = conditions[idx]
else:
raise ValueError("Invalid number of arguments.")

View File

@@ -1,84 +1,8 @@
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
class SamplePointDataset(Dataset):
"""
This class is used to create a dataset of sample points.
"""
def __init__(self, problem, device) -> None:
"""
:param dict input_pts: The input points.
"""
super().__init__()
pts_list = []
self.condition_names = []
for name, condition in problem.conditions.items():
if not hasattr(condition, "output_points"):
pts_list.append(problem.input_pts[name])
self.condition_names.append(name)
self.pts = LabelTensor.vstack(pts_list)
if self.pts != []:
self.condition_indeces = torch.cat(
[
torch.tensor([i] * len(pts_list[i]))
for i in range(len(self.condition_names))
],
dim=0,
)
else: # if there are no sample points
self.condition_indeces = torch.tensor([])
self.pts = torch.tensor([])
self.pts = self.pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.pts.shape[0]
class DataPointDataset(Dataset):
def __init__(self, problem, device) -> None:
super().__init__()
input_list = []
output_list = []
self.condition_names = []
for name, condition in problem.conditions.items():
if hasattr(condition, "output_points"):
input_list.append(problem.conditions[name].input_points)
output_list.append(problem.conditions[name].output_points)
self.condition_names.append(name)
self.input_pts = LabelTensor.vstack(input_list)
self.output_pts = LabelTensor.vstack(output_list)
if self.input_pts != []:
self.condition_indeces = torch.cat(
[
torch.tensor([i] * len(input_list[i]))
for i in range(len(self.condition_names))
],
dim=0,
)
else: # if there are no data points
self.condition_indeces = torch.tensor([])
self.input_pts = torch.tensor([])
self.output_pts = torch.tensor([])
self.input_pts = self.input_pts.to(device)
self.output_pts = self.output_pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.input_pts.shape[0]
from .sample_dataset import SamplePointDataset
from .data_dataset import DataPointDataset
from .pina_batch import Batch
class SamplePointLoader:
"""
@@ -133,6 +57,8 @@ class SamplePointLoader:
else:
self.random_idx = torch.arange(len(self.batch_list))
self._prepare_batches()
def _prepare_data_dataset(self, dataset, batch_size, shuffle):
"""
Prepare the dataset for data points.
@@ -169,7 +95,7 @@ class SamplePointLoader:
self.batch_output_pts = torch.tensor_split(
dataset.output_pts, batch_num
)
print(input_labels)
for i in range(len(self.batch_input_pts)):
self.batch_input_pts[i].labels = input_labels
self.batch_output_pts[i].labels = output_labels
@@ -216,6 +142,29 @@ class SamplePointLoader:
self.tensor_conditions, batch_num
)
def _prepare_batches(self):
"""
Prepare the batches.
"""
self.batches = []
for i in range(len(self.batch_list)):
type_, idx_ = self.batch_list[i]
if type_ == "sample":
batch = Batch(
"sample", idx_,
self.batch_sample_pts,
self.batch_sample_conditions)
else:
batch = Batch(
"data", idx_,
self.batch_input_pts,
self.batch_output_pts,
self.batch_data_conditions)
print(batch.input.labels)
self.batches.append(batch)
def __iter__(self):
"""
Return an iterator over the points. Any element of the iterator is a
@@ -233,21 +182,24 @@ class SamplePointLoader:
:rtype: iter
"""
# for i in self.random_idx:
for i in range(len(self.batch_list)):
type_, idx_ = self.batch_list[i]
for i in self.random_idx:
yield self.batches[i]
if type_ == "sample":
d = {
"pts": self.batch_sample_pts[idx_].requires_grad_(True),
"condition": self.batch_sample_conditions[idx_],
}
else:
d = {
"pts": self.batch_input_pts[idx_].requires_grad_(True),
"output": self.batch_output_pts[idx_],
"condition": self.batch_data_conditions[idx_],
}
yield d
# for i in range(len(self.batch_list)):
# type_, idx_ = self.batch_list[i]
# if type_ == "sample":
# d = {
# "pts": self.batch_sample_pts[idx_].requires_grad_(True),
# "condition": self.batch_sample_conditions[idx_],
# }
# else:
# d = {
# "pts": self.batch_input_pts[idx_].requires_grad_(True),
# "output": self.batch_output_pts[idx_],
# "condition": self.batch_data_conditions[idx_],
# }
# yield d
def __len__(self):
"""

View File

@@ -0,0 +1,43 @@
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
class SamplePointDataset(Dataset):
"""
This class is used to create a dataset of sample points.
"""
def __init__(self, problem, device) -> None:
"""
:param AbstractProblem problem: the problem whose conditions provide the sample points.
:param torch.device device: the device on which the points are stored.
"""
super().__init__()
pts_list = []
self.condition_names = []
for name, condition in problem.conditions.items():
if not hasattr(condition, "output_points"):
pts_list.append(problem.input_pts[name])
self.condition_names.append(name)
self.pts = LabelTensor.stack(pts_list)
if self.pts != []:
self.condition_indeces = torch.cat(
[
torch.tensor([i] * len(pts_list[i]))
for i in range(len(self.condition_names))
],
dim=0,
)
else: # if there are no sample points
self.condition_indeces = torch.tensor([])
self.pts = torch.tensor([])
self.pts = self.pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.pts.shape[0]

View File

@@ -1,5 +1,5 @@
__all__ = [
"Location",
"DomainInterface",
"CartesianDomain",
"EllipsoidDomain",
"Union",
@@ -10,7 +10,7 @@ __all__ = [
"SimplexDomain",
]
from .location import Location
from .domain_interface import DomainInterface
from .cartesian import CartesianDomain
from .ellipsoid import EllipsoidDomain
from .exclusion_domain import Exclusion

View File

@@ -1,11 +1,11 @@
import torch
from .location import Location
from .domain_interface import DomainInterface
from ..label_tensor import LabelTensor
from ..utils import torch_lhs, chebyshev_roots
class CartesianDomain(Location):
class CartesianDomain(DomainInterface):
"""PINA implementation of Hypercube domain."""
def __init__(self, cartesian_dict):

View File

@@ -1,9 +1,9 @@
"""Module for Location class."""
"""Module for the DomainInterface class."""
from abc import ABCMeta, abstractmethod
class Location(metaclass=ABCMeta):
class DomainInterface(metaclass=ABCMeta):
"""
Abstract DomainInterface class.
Any geometry entity should inherit from this class.

View File

@@ -1,11 +1,11 @@
import torch
from .location import Location
from .domain_interface import DomainInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
class EllipsoidDomain(Location):
class EllipsoidDomain(DomainInterface):
"""PINA implementation of Ellipsoid domain."""
def __init__(self, ellipsoid_dict, sample_surface=False):

View File

@@ -1,11 +1,11 @@
""" Module for OperationInterface class. """
from .location import Location
from .domain_interface import DomainInterface
from ..utils import check_consistency
from abc import ABCMeta, abstractmethod
class OperationInterface(Location, metaclass=ABCMeta):
class OperationInterface(DomainInterface, metaclass=ABCMeta):
def __init__(self, geometries):
"""
@@ -15,7 +15,7 @@ class OperationInterface(Location, metaclass=ABCMeta):
such as ``EllipsoidDomain`` or ``CartesianDomain``.
"""
# check consistency geometries
check_consistency(geometries, Location)
check_consistency(geometries, DomainInterface)
# check we are passing always different
# geometries with the same labels.

View File

@@ -1,11 +1,11 @@
import torch
from .location import Location
from pina.geometry import CartesianDomain
from .domain_interface import DomainInterface
from pina.domain import CartesianDomain
from ..label_tensor import LabelTensor
from ..utils import check_consistency
class SimplexDomain(Location):
class SimplexDomain(DomainInterface):
"""PINA implementation of a Simplex."""
def __init__(self, simplex_matrix, sample_surface=False):

View File

@@ -229,10 +229,6 @@ from torch import Tensor
# detached._labels = self._labels
# return detached
# def requires_grad_(self, mode=True):
# lt = super().requires_grad_(mode)
# lt.labels = self.labels
# return lt
# def append(self, lt, mode="std"):
# """
@@ -406,11 +402,29 @@ class LabelTensor(torch.Tensor):
return LabelTensor(new_tensor, label_to_extract)
def __str__(self):
s = ''
for key, value in self.labels.items():
s += f"{key}: {value}\n"
s += '\n'
s += super().__str__()
return s
return s
@staticmethod
def stack(tensors):
"""
"""
if len(tensors) == 0:
return []
if len(tensors) == 1:
return tensors[0]
raise NotImplementedError
labels = [tensor.labels for tensor in tensors]
print(labels)
def requires_grad_(self, mode=True):
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
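The un-commented requires_grad_ override matters because the base torch.Tensor method returns a tensor whose labels metadata is lost; a minimal sketch, assuming the usual LabelTensor constructor:

import torch
from pina import LabelTensor

lt = LabelTensor(torch.rand(4, 2), ['x', 'y'])  # constructor signature assumed
lt = lt.requires_grad_(True)
print(lt.labels)  # labels survive: the override re-attaches them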

View File

@@ -5,6 +5,8 @@ from ..utils import merge_tensors, check_consistency
from copy import deepcopy
import torch
from .. import LabelTensor
class AbstractProblem(metaclass=ABCMeta):
"""
@@ -18,17 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
# variable storing all points
self.input_pts = {}
# variable to check if sampling is done. If no location
# element is present in the Condition, this variable is set to True
self._have_sampled_points = {}
self._discretized_domains = {}
for name, domain in self.domains.items():
if isinstance(domain, (torch.Tensor, LabelTensor)):
self._discretized_domains[name] = domain
for condition_name in self.conditions:
self._have_sampled_points[condition_name] = False
self.conditions[condition_name]._problem = self
# # variable storing all points
# self.input_pts = {}
# put in self.input_pts all the points that we don't need to sample
self._span_condition_points()
# # variable to check if sampling is done. If no location
# # element is present in the Condition, this variable is set to True
# self._have_sampled_points = {}
# for condition_name in self.conditions:
# self._have_sampled_points[condition_name] = False
# # put in self.input_pts all the points that we don't need to sample
# self._span_condition_points()
def __deepcopy__(self, memo):
"""
@@ -63,15 +74,20 @@ class AbstractProblem(metaclass=ABCMeta):
variables += self.spatial_variables
if hasattr(self, "temporal_variable"):
variables += self.temporal_variable
if hasattr(self, "parameters"):
if hasattr(self, "unknown_parameters"):
variables += self.parameters
if hasattr(self, "custom_variables"):
variables += self.custom_variables
return variables
@input_variables.setter
def input_variables(self, variables):
raise RuntimeError
@property
def domain(self):
@abstractmethod
def domains(self):
"""
The domain(s) where the conditions of the AbstractProblem are valid.
If more than one domain type is passed, a list of Location is
@@ -80,27 +96,7 @@ class AbstractProblem(metaclass=ABCMeta):
:return: the domain(s) of ``self``
:rtype: list[Location]
"""
domains = [
getattr(self, f"{t}_domain")
for t in ["spatial", "temporal", "parameter"]
if hasattr(self, f"{t}_domain")
]
if len(domains) == 1:
return domains[0]
elif len(domains) == 0:
raise RuntimeError
if len(set(map(type, domains))) == 1:
domain = domains[0].__class__({})
[domain.update(d) for d in domains]
return domain
else:
raise RuntimeError("different domains")
@input_variables.setter
def input_variables(self, variables):
raise RuntimeError
pass
@property
@abstractmethod
@@ -116,7 +112,9 @@ class AbstractProblem(metaclass=ABCMeta):
"""
The conditions of the problem.
"""
pass
return self._conditions
def _span_condition_points(self):
"""
@@ -281,28 +279,4 @@ class AbstractProblem(metaclass=ABCMeta):
# merging
merged_pts = torch.vstack([old_pts, new_pts])
merged_pts.labels = old_pts.labels
self.input_pts[location] = merged_pts
@property
def have_sampled_points(self):
"""
Check if all points for
``Location`` are sampled.
"""
return all(self._have_sampled_points.values())
@property
def not_sampled_points(self):
"""
Check which points are
not sampled.
"""
# variables which are not sampled
not_sampled = None
if self.have_sampled_points is False:
# check which one are not sampled:
not_sampled = []
for condition_name, is_sample in self._have_sampled_points.items():
if not is_sample:
not_sampled.append(condition_name)
return not_sampled
self.input_pts[location] = merged_pts
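What a concrete problem might look like under the new interface, as a hedged sketch (the import path of AbstractProblem and the ToyProblem class are assumptions):

import torch
from pina import Condition
from pina.domain import CartesianDomain
from pina.problem import AbstractProblem  # import path assumed

class ToyProblem(AbstractProblem):
    output_variables = ['u']
    # `domains` is now an abstract property: a plain mapping suffices here
    domains = {'interior': CartesianDomain({'x': [0, 1]})}
    conditions = {
        'data': Condition(input_points=torch.rand(10, 1),
                          output_points=torch.rand(10, 1)),
    }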

View File

@@ -205,6 +205,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
# put everything in a list if only one input
if not isinstance(model, list):
model = [model]
if not isinstance(scheduler, list):
scheduler = [scheduler]
if not isinstance(optimizer, list):
optimizer = [optimizer]

View File

@@ -82,6 +82,7 @@ class SupervisedSolver(SolverInterface):
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
self.loss = loss
def forward(self, x):
"""Forward pass implementation for the solver.
@@ -90,7 +91,16 @@ class SupervisedSolver(SolverInterface):
:return: Solver solution.
:rtype: torch.Tensor
"""
return self._pina_model(x)
output = self._pina_model[0](x)
output.labels = {
1: {
"name": "output",
"dof": self.problem.output_variables
}
}
return output
def configure_optimizers(self):
"""Optimizer configuration for the solver.
@@ -98,9 +108,12 @@ class SupervisedSolver(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
self._pina_optimizer.hook(self._pina_model.parameters())
self._pina_scheduler.hook(self._pina_optimizer)
return self._pina_optimizer, self._pina_scheduler
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
self._pina_scheduler[0].hook(self._pina_optimizer[0])
return (
[self._pina_optimizer[0].optimizer_instance],
[self._pina_scheduler[0].scheduler_instance]
)
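The new return value matches PyTorch Lightning's configure_optimizers contract, which expects raw optimizer and scheduler instances rather than the wrapper objects; a generic Lightning sketch of the same shape:

import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters())
        sched = torch.optim.lr_scheduler.ConstantLR(opt)
        # ([optimizers], [schedulers]) is one of Lightning's accepted formats
        return [opt], [sched]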
def training_step(self, batch, batch_idx):
"""Solver training step.
@@ -113,14 +126,16 @@ class SupervisedSolver(SolverInterface):
:rtype: LabelTensor
"""
condition_idx = batch["condition"]
condition_idx = batch.condition
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch["pts"]
out = batch["output"]
pts = batch.input
out = batch.output
print(out)
print(pts)
if condition_name not in self.problem.conditions:
raise RuntimeError("Something wrong happened.")
@@ -134,9 +149,11 @@ class SupervisedSolver(SolverInterface):
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]
input_pts.labels = pts.labels
output_pts.labels = out.labels
loss = (
self.loss_data(input_pts=input_pts, output_pts=output_pts)
* condition.data_weight
)
loss = loss.as_subclass(torch.Tensor)
@@ -155,6 +172,10 @@ class SupervisedSolver(SolverInterface):
:return: The residual loss averaged on the input coordinates
:rtype: torch.Tensor
"""
print(input_pts)
print(output_pts)
print(self.loss)
print(self.forward(input_pts))
return self.loss(self.forward(input_pts), output_pts)
@property

View File

@@ -3,7 +3,7 @@
import torch
import pytorch_lightning
from .utils import check_consistency
from .data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface
@@ -35,19 +35,33 @@ class Trainer(pytorch_lightning.Trainer):
self._model = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
# create dataloader
if solver.problem.have_sampled_points is False:
raise RuntimeError(
f"Input points in {solver.problem.not_sampled_points} "
"training are None. Please "
"sample points in your problem by calling "
"discretise_domain function before train "
"in the provided locations."
)
# if solver.problem.have_sampled_points is False:
# raise RuntimeError(
# f"Input points in {solver.problem.not_sampled_points} "
# "training are None. Please "
# "sample points in your problem by calling "
# "discretise_domain function before train "
# "in the provided locations."
# )
self._create_or_update_loader()
# self._create_or_update_loader()
def _create_or_update_loader(self):
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self._model.problem
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device)
)
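The relocation in _move_to_device in miniature: re-wrapping the moved data keeps the tensor a trainable Parameter (the device below stands in for the accelerator's first parallel device):

import torch

param = torch.nn.Parameter(torch.rand(3))
device = torch.device("cpu")  # stand-in for the accelerator device
param = torch.nn.Parameter(param.data.to(device))
print(param.requires_grad)  # True: still trainable after the move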
def _create_loader(self):
"""
This method is used here because, if resampling is needed
during training, there is no need to touch the
@@ -64,12 +78,6 @@ class Trainer(pytorch_lightning.Trainer):
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
)
pb = self._model.problem
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device)
)
def train(self, **kwargs):
"""