supervised working
This commit is contained in:
@@ -15,11 +15,8 @@ from .label_tensor import LabelTensor
|
||||
from .solvers.solver import SolverInterface
|
||||
from .trainer import Trainer
|
||||
from .plotter import Plotter
|
||||
from .condition import Condition
|
||||
from .dataset import SamplePointDataset
|
||||
from .dataset import SamplePointLoader
|
||||
from .optimizer import TorchOptimizer
|
||||
from .scheduler import TorchScheduler
|
||||
from .condition.condition import Condition
|
||||
from .data.dataset import SamplePointDataset
|
||||
from .data.dataset import SamplePointLoader
|
||||
from .data import SamplePointDataset
|
||||
from .data import SamplePointLoader
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
__all__ = [
|
||||
'Condition',
|
||||
'ConditionInterface',
|
||||
'InputOutputCondition',
|
||||
'InputEquationCondition'
|
||||
'LocationEquationCondition',
|
||||
'DomainOutputCondition',
|
||||
'DomainEquationCondition'
|
||||
]
|
||||
|
||||
from .condition_interface import ConditionInterface
|
||||
from .input_output_condition import InputOutputCondition
|
||||
from .domain_output_condition import DomainOutputCondition
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
@@ -1,9 +1,11 @@
|
||||
""" Condition module. """
|
||||
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..geometry import Location
|
||||
from ..domain import DomainInterface
|
||||
from ..equation.equation import Equation
|
||||
|
||||
from . import DomainOutputCondition, DomainEquationCondition
|
||||
|
||||
|
||||
def dummy(a):
|
||||
"""Dummy function for testing purposes."""
|
||||
@@ -51,14 +53,6 @@ class Condition:
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
"input_points",
|
||||
"output_points",
|
||||
"location",
|
||||
"equation",
|
||||
"data_weight",
|
||||
]
|
||||
|
||||
# def _dictvalue_isinstance(self, dict_, key_, class_):
|
||||
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
||||
# if key_ not in dict_.keys():
|
||||
@@ -77,11 +71,17 @@ class Condition:
|
||||
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
|
||||
# )
|
||||
|
||||
from . import InputOutputCondition
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
||||
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
|
||||
return InputOutputCondition(**kwargs)
|
||||
return DomainOutputCondition(
|
||||
domain=kwargs["input_points"],
|
||||
output_points=kwargs["output_points"]
|
||||
)
|
||||
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
|
||||
return DomainOutputCondition(**kwargs)
|
||||
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
|
||||
return DomainEquationCondition(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._problem = None
|
||||
|
||||
@abstractmethod
|
||||
def residual(self, model):
|
||||
"""
|
||||
|
||||
@@ -1,26 +1,34 @@
|
||||
|
||||
from . import ConditionInterface
|
||||
|
||||
class InputOutputCondition(ConditionInterface):
|
||||
class DomainOutputCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for input/output data.
|
||||
"""
|
||||
|
||||
__slots__ = ["input_points", "output_points"]
|
||||
__slots__ = ["domain", "output_points"]
|
||||
|
||||
def __init__(self, input_points, output_points):
|
||||
def __init__(self, domain, output_points):
|
||||
"""
|
||||
Constructor for the `InputOutputCondition` class.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_points = input_points
|
||||
print(self)
|
||||
self.domain = domain
|
||||
self.output_points = output_points
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
"""
|
||||
Get the input points of the condition.
|
||||
"""
|
||||
return self._problem.domains[self.domain]
|
||||
|
||||
def residual(self, model):
|
||||
"""
|
||||
Compute the residual of the condition.
|
||||
"""
|
||||
return self.batch_residual(model, self.input_points, self.output_points)
|
||||
return self.batch_residual(model, self.domain, self.output_points)
|
||||
|
||||
@staticmethod
|
||||
def batch_residual(model, input_points, output_points):
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
from . import ConditionInterface
|
||||
|
||||
class InputOutputCondition(ConditionInterface):
|
||||
class InputEquationCondition(ConditionInterface):
|
||||
"""
|
||||
Condition for input/output data.
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
__all__ = [
|
||||
]
|
||||
|
||||
from .pina_dataloader import SamplePointLoader
|
||||
from .data_dataset import DataPointDataset
|
||||
from .sample_dataset import SamplePointDataset
|
||||
from .pina_batch import Batch
|
||||
41
pina/data/data_dataset.py
Normal file
41
pina/data/data_dataset.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class DataPointDataset(Dataset):
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
super().__init__()
|
||||
input_list = []
|
||||
output_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if hasattr(condition, "output_points"):
|
||||
input_list.append(problem.conditions[name].input_points)
|
||||
output_list.append(problem.conditions[name].output_points)
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.input_pts = LabelTensor.stack(input_list)
|
||||
self.output_pts = LabelTensor.stack(output_list)
|
||||
|
||||
if self.input_pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(input_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no data points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.input_pts = torch.tensor([])
|
||||
self.output_pts = torch.tensor([])
|
||||
|
||||
self.input_pts = self.input_pts.to(device)
|
||||
self.output_pts = self.output_pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.input_pts.shape[0]
|
||||
36
pina/data/pina_batch.py
Normal file
36
pina/data/pina_batch.py
Normal file
@@ -0,0 +1,36 @@
|
||||
|
||||
|
||||
class Batch:
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
|
||||
def __init__(self, type_, idx, *args, **kwargs) -> None:
|
||||
"""
|
||||
"""
|
||||
if type_ == "sample":
|
||||
|
||||
if len(args) != 2:
|
||||
raise RuntimeError
|
||||
|
||||
input = args[0]
|
||||
conditions = args[1]
|
||||
|
||||
self.input = input[idx]
|
||||
self.condition = conditions[idx]
|
||||
|
||||
elif type_ == "data":
|
||||
|
||||
if len(args) != 3:
|
||||
raise RuntimeError
|
||||
|
||||
input = args[0]
|
||||
output = args[1]
|
||||
conditions = args[2]
|
||||
|
||||
self.input = input[idx]
|
||||
self.output = output[idx]
|
||||
self.condition = conditions[idx]
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid number of arguments.")
|
||||
@@ -1,84 +1,8 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class SamplePointDataset(Dataset):
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
"""
|
||||
:param dict input_pts: The input points.
|
||||
"""
|
||||
super().__init__()
|
||||
pts_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if not hasattr(condition, "output_points"):
|
||||
pts_list.append(problem.input_pts[name])
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.pts = LabelTensor.vstack(pts_list)
|
||||
|
||||
if self.pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(pts_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no sample points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.pts = torch.tensor([])
|
||||
|
||||
self.pts = self.pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.pts.shape[0]
|
||||
|
||||
|
||||
class DataPointDataset(Dataset):
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
super().__init__()
|
||||
input_list = []
|
||||
output_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if hasattr(condition, "output_points"):
|
||||
input_list.append(problem.conditions[name].input_points)
|
||||
output_list.append(problem.conditions[name].output_points)
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.input_pts = LabelTensor.vstack(input_list)
|
||||
self.output_pts = LabelTensor.vstack(output_list)
|
||||
|
||||
if self.input_pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(input_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no data points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.input_pts = torch.tensor([])
|
||||
self.output_pts = torch.tensor([])
|
||||
|
||||
self.input_pts = self.input_pts.to(device)
|
||||
self.output_pts = self.output_pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.input_pts.shape[0]
|
||||
|
||||
from .sample_dataset import SamplePointDataset
|
||||
from .data_dataset import DataPointDataset
|
||||
from .pina_batch import Batch
|
||||
|
||||
class SamplePointLoader:
|
||||
"""
|
||||
@@ -133,6 +57,8 @@ class SamplePointLoader:
|
||||
else:
|
||||
self.random_idx = torch.arange(len(self.batch_list))
|
||||
|
||||
self._prepare_batches()
|
||||
|
||||
def _prepare_data_dataset(self, dataset, batch_size, shuffle):
|
||||
"""
|
||||
Prepare the dataset for data points.
|
||||
@@ -169,7 +95,7 @@ class SamplePointLoader:
|
||||
self.batch_output_pts = torch.tensor_split(
|
||||
dataset.output_pts, batch_num
|
||||
)
|
||||
|
||||
print(input_labels)
|
||||
for i in range(len(self.batch_input_pts)):
|
||||
self.batch_input_pts[i].labels = input_labels
|
||||
self.batch_output_pts[i].labels = output_labels
|
||||
@@ -216,6 +142,29 @@ class SamplePointLoader:
|
||||
self.tensor_conditions, batch_num
|
||||
)
|
||||
|
||||
def _prepare_batches(self):
|
||||
"""
|
||||
Prepare the batches.
|
||||
"""
|
||||
self.batches = []
|
||||
for i in range(len(self.batch_list)):
|
||||
type_, idx_ = self.batch_list[i]
|
||||
|
||||
if type_ == "sample":
|
||||
batch = Batch(
|
||||
"sample", idx_,
|
||||
self.batch_sample_pts,
|
||||
self.batch_sample_conditions)
|
||||
else:
|
||||
batch = Batch(
|
||||
"data", idx_,
|
||||
self.batch_input_pts,
|
||||
self.batch_output_pts,
|
||||
self.batch_data_conditions)
|
||||
print(batch.input.labels)
|
||||
|
||||
self.batches.append(batch)
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Return an iterator over the points. Any element of the iterator is a
|
||||
@@ -233,21 +182,24 @@ class SamplePointLoader:
|
||||
:rtype: iter
|
||||
"""
|
||||
# for i in self.random_idx:
|
||||
for i in range(len(self.batch_list)):
|
||||
type_, idx_ = self.batch_list[i]
|
||||
for i in self.random_idx:
|
||||
yield self.batches[i]
|
||||
|
||||
if type_ == "sample":
|
||||
d = {
|
||||
"pts": self.batch_sample_pts[idx_].requires_grad_(True),
|
||||
"condition": self.batch_sample_conditions[idx_],
|
||||
}
|
||||
else:
|
||||
d = {
|
||||
"pts": self.batch_input_pts[idx_].requires_grad_(True),
|
||||
"output": self.batch_output_pts[idx_],
|
||||
"condition": self.batch_data_conditions[idx_],
|
||||
}
|
||||
yield d
|
||||
# for i in range(len(self.batch_list)):
|
||||
# type_, idx_ = self.batch_list[i]
|
||||
|
||||
# if type_ == "sample":
|
||||
# d = {
|
||||
# "pts": self.batch_sample_pts[idx_].requires_grad_(True),
|
||||
# "condition": self.batch_sample_conditions[idx_],
|
||||
# }
|
||||
# else:
|
||||
# d = {
|
||||
# "pts": self.batch_input_pts[idx_].requires_grad_(True),
|
||||
# "output": self.batch_output_pts[idx_],
|
||||
# "condition": self.batch_data_conditions[idx_],
|
||||
# }
|
||||
# yield d
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
43
pina/data/sample_dataset.py
Normal file
43
pina/data/sample_dataset.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class SamplePointDataset(Dataset):
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
|
||||
def __init__(self, problem, device) -> None:
|
||||
"""
|
||||
:param dict input_pts: The input points.
|
||||
"""
|
||||
super().__init__()
|
||||
pts_list = []
|
||||
self.condition_names = []
|
||||
|
||||
for name, condition in problem.conditions.items():
|
||||
if not hasattr(condition, "output_points"):
|
||||
pts_list.append(problem.input_pts[name])
|
||||
self.condition_names.append(name)
|
||||
|
||||
self.pts = LabelTensor.stack(pts_list)
|
||||
|
||||
if self.pts != []:
|
||||
self.condition_indeces = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(pts_list[i]))
|
||||
for i in range(len(self.condition_names))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else: # if there are no sample points
|
||||
self.condition_indeces = torch.tensor([])
|
||||
self.pts = torch.tensor([])
|
||||
|
||||
self.pts = self.pts.to(device)
|
||||
self.condition_indeces = self.condition_indeces.to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.pts.shape[0]
|
||||
@@ -1,5 +1,5 @@
|
||||
__all__ = [
|
||||
"Location",
|
||||
"DomainInterface",
|
||||
"CartesianDomain",
|
||||
"EllipsoidDomain",
|
||||
"Union",
|
||||
@@ -10,7 +10,7 @@ __all__ = [
|
||||
"SimplexDomain",
|
||||
]
|
||||
|
||||
from .location import Location
|
||||
from .domain_interface import DomainInterface
|
||||
from .cartesian import CartesianDomain
|
||||
from .ellipsoid import EllipsoidDomain
|
||||
from .exclusion_domain import Exclusion
|
||||
@@ -1,11 +1,11 @@
|
||||
import torch
|
||||
|
||||
from .location import Location
|
||||
from .domain_interface import DomainInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import torch_lhs, chebyshev_roots
|
||||
|
||||
|
||||
class CartesianDomain(Location):
|
||||
class CartesianDomain(DomainInterface):
|
||||
"""PINA implementation of Hypercube domain."""
|
||||
|
||||
def __init__(self, cartesian_dict):
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Module for Location class."""
|
||||
"""Module for the DomainInterface class."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class Location(metaclass=ABCMeta):
|
||||
class DomainInterface(metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract Location class.
|
||||
Any geometry entity should inherit from this class.
|
||||
@@ -1,11 +1,11 @@
|
||||
import torch
|
||||
|
||||
from .location import Location
|
||||
from .domain_interface import DomainInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class EllipsoidDomain(Location):
|
||||
class EllipsoidDomain(DomainInterface):
|
||||
"""PINA implementation of Ellipsoid domain."""
|
||||
|
||||
def __init__(self, ellipsoid_dict, sample_surface=False):
|
||||
@@ -1,11 +1,11 @@
|
||||
""" Module for OperationInterface class. """
|
||||
|
||||
from .location import Location
|
||||
from .domain_interface import DomainInterface
|
||||
from ..utils import check_consistency
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class OperationInterface(Location, metaclass=ABCMeta):
|
||||
class OperationInterface(DomainInterface, metaclass=ABCMeta):
|
||||
|
||||
def __init__(self, geometries):
|
||||
"""
|
||||
@@ -15,7 +15,7 @@ class OperationInterface(Location, metaclass=ABCMeta):
|
||||
such as ``EllipsoidDomain`` or ``CartesianDomain``.
|
||||
"""
|
||||
# check consistency geometries
|
||||
check_consistency(geometries, Location)
|
||||
check_consistency(geometries, DomainInterface)
|
||||
|
||||
# check we are passing always different
|
||||
# geometries with the same labels.
|
||||
@@ -1,11 +1,11 @@
|
||||
import torch
|
||||
from .location import Location
|
||||
from pina.geometry import CartesianDomain
|
||||
from .domain_interface import DomainInterface
|
||||
from pina.domain import CartesianDomain
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class SimplexDomain(Location):
|
||||
class SimplexDomain(DomainInterface):
|
||||
"""PINA implementation of a Simplex."""
|
||||
|
||||
def __init__(self, simplex_matrix, sample_surface=False):
|
||||
@@ -229,10 +229,6 @@ from torch import Tensor
|
||||
# detached._labels = self._labels
|
||||
# return detached
|
||||
|
||||
# def requires_grad_(self, mode=True):
|
||||
# lt = super().requires_grad_(mode)
|
||||
# lt.labels = self.labels
|
||||
# return lt
|
||||
|
||||
# def append(self, lt, mode="std"):
|
||||
# """
|
||||
@@ -406,11 +402,29 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
return LabelTensor(new_tensor, label_to_extract)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
s = ''
|
||||
for key, value in self.labels.items():
|
||||
s += f"{key}: {value}\n"
|
||||
s += '\n'
|
||||
s += super().__str__()
|
||||
return s
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def stack(tensors):
|
||||
"""
|
||||
"""
|
||||
if len(tensors) == 0:
|
||||
return []
|
||||
|
||||
if len(tensors) == 1:
|
||||
return tensors[0]
|
||||
|
||||
raise NotImplementedError
|
||||
labels = [tensor.labels for tensor in tensors]
|
||||
print(labels)
|
||||
|
||||
def requires_grad_(self, mode=True):
|
||||
lt = super().requires_grad_(mode)
|
||||
lt.labels = self.labels
|
||||
return lt
|
||||
@@ -5,6 +5,8 @@ from ..utils import merge_tensors, check_consistency
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
from .. import LabelTensor
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
@@ -18,17 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
# variable storing all points
|
||||
self.input_pts = {}
|
||||
|
||||
# varible to check if sampling is done. If no location
|
||||
# element is presented in Condition this variable is set to true
|
||||
self._have_sampled_points = {}
|
||||
self._discretized_domains = {}
|
||||
|
||||
for name, domain in self.domains.items():
|
||||
if isinstance(domain, (torch.Tensor, LabelTensor)):
|
||||
self._discretized_domains[name] = domain
|
||||
|
||||
for condition_name in self.conditions:
|
||||
self._have_sampled_points[condition_name] = False
|
||||
self.conditions[condition_name]._problem = self
|
||||
# # variable storing all points
|
||||
# self.input_pts = {}
|
||||
|
||||
# put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
# # varible to check if sampling is done. If no location
|
||||
# # element is presented in Condition this variable is set to true
|
||||
# self._have_sampled_points = {}
|
||||
# for condition_name in self.conditions:
|
||||
# self._have_sampled_points[condition_name] = False
|
||||
|
||||
# # put in self.input_pts all the points that we don't need to sample
|
||||
# self._span_condition_points()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
@@ -63,15 +74,20 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
variables += self.spatial_variables
|
||||
if hasattr(self, "temporal_variable"):
|
||||
variables += self.temporal_variable
|
||||
if hasattr(self, "parameters"):
|
||||
if hasattr(self, "unknown_parameters"):
|
||||
variables += self.parameters
|
||||
if hasattr(self, "custom_variables"):
|
||||
variables += self.custom_variables
|
||||
|
||||
return variables
|
||||
|
||||
@input_variables.setter
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
@abstractmethod
|
||||
def domains(self):
|
||||
"""
|
||||
The domain(s) where the conditions of the AbstractProblem are valid.
|
||||
If more than one domain type is passed, a list of Location is
|
||||
@@ -80,27 +96,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
:return: the domain(s) of ``self``
|
||||
:rtype: list[Location]
|
||||
"""
|
||||
domains = [
|
||||
getattr(self, f"{t}_domain")
|
||||
for t in ["spatial", "temporal", "parameter"]
|
||||
if hasattr(self, f"{t}_domain")
|
||||
]
|
||||
|
||||
if len(domains) == 1:
|
||||
return domains[0]
|
||||
elif len(domains) == 0:
|
||||
raise RuntimeError
|
||||
|
||||
if len(set(map(type, domains))) == 1:
|
||||
domain = domains[0].__class__({})
|
||||
[domain.update(d) for d in domains]
|
||||
return domain
|
||||
else:
|
||||
raise RuntimeError("different domains")
|
||||
|
||||
@input_variables.setter
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -116,7 +112,9 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
The conditions of the problem.
|
||||
"""
|
||||
pass
|
||||
return self._conditions
|
||||
|
||||
|
||||
|
||||
def _span_condition_points(self):
|
||||
"""
|
||||
@@ -281,28 +279,4 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# merging
|
||||
merged_pts = torch.vstack([old_pts, new_pts])
|
||||
merged_pts.labels = old_pts.labels
|
||||
self.input_pts[location] = merged_pts
|
||||
|
||||
@property
|
||||
def have_sampled_points(self):
|
||||
"""
|
||||
Check if all points for
|
||||
``Location`` are sampled.
|
||||
"""
|
||||
return all(self._have_sampled_points.values())
|
||||
|
||||
@property
|
||||
def not_sampled_points(self):
|
||||
"""
|
||||
Check which points are
|
||||
not sampled.
|
||||
"""
|
||||
# variables which are not sampled
|
||||
not_sampled = None
|
||||
if self.have_sampled_points is False:
|
||||
# check which one are not sampled:
|
||||
not_sampled = []
|
||||
for condition_name, is_sample in self._have_sampled_points.items():
|
||||
if not is_sample:
|
||||
not_sampled.append(condition_name)
|
||||
return not_sampled
|
||||
self.input_pts[location] = merged_pts
|
||||
@@ -205,6 +205,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
# put everything in a list if only one input
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
if not isinstance(scheduler, list):
|
||||
scheduler = [scheduler]
|
||||
if not isinstance(optimizer, list):
|
||||
optimizer = [optimizer]
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ class SupervisedSolver(SolverInterface):
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
self.loss = loss
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the solver.
|
||||
@@ -90,7 +91,16 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._pina_model(x)
|
||||
|
||||
output = self._pina_model[0](x)
|
||||
|
||||
output.labels = {
|
||||
1: {
|
||||
"name": "output",
|
||||
"dof": self.problem.output_variables
|
||||
}
|
||||
}
|
||||
return output
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Optimizer configuration for the solver.
|
||||
@@ -98,9 +108,12 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
self._pina_optimizer.hook(self._pina_model.parameters())
|
||||
self._pina_scheduler.hook(self._pina_optimizer)
|
||||
return self._pina_optimizer, self._pina_scheduler
|
||||
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
|
||||
self._pina_scheduler[0].hook(self._pina_optimizer[0])
|
||||
return (
|
||||
[self._pina_optimizer[0].optimizer_instance],
|
||||
[self._pina_scheduler[0].scheduler_instance]
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""Solver training step.
|
||||
@@ -113,14 +126,16 @@ class SupervisedSolver(SolverInterface):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
condition_idx = batch["condition"]
|
||||
condition_idx = batch.condition
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
condition_name = self._dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch["pts"]
|
||||
out = batch["output"]
|
||||
pts = batch.input
|
||||
out = batch.output
|
||||
print(out)
|
||||
print(pts)
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError("Something wrong happened.")
|
||||
@@ -134,9 +149,11 @@ class SupervisedSolver(SolverInterface):
|
||||
output_pts = out[condition_idx == condition_id]
|
||||
input_pts = pts[condition_idx == condition_id]
|
||||
|
||||
input_pts.labels = pts.labels
|
||||
output_pts.labels = out.labels
|
||||
|
||||
loss = (
|
||||
self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
* condition.data_weight
|
||||
)
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
@@ -155,6 +172,10 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
print(input_pts)
|
||||
print(output_pts)
|
||||
print(self.loss)
|
||||
print(self.forward(input_pts))
|
||||
return self.loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import pytorch_lightning
|
||||
from .utils import check_consistency
|
||||
from .data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||
from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||
from .solvers.solver import SolverInterface
|
||||
|
||||
|
||||
@@ -35,19 +35,33 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
self._model = solver
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._create_loader()
|
||||
self._move_to_device()
|
||||
|
||||
# create dataloader
|
||||
if solver.problem.have_sampled_points is False:
|
||||
raise RuntimeError(
|
||||
f"Input points in {solver.problem.not_sampled_points} "
|
||||
"training are None. Please "
|
||||
"sample points in your problem by calling "
|
||||
"discretise_domain function before train "
|
||||
"in the provided locations."
|
||||
)
|
||||
# if solver.problem.have_sampled_points is False:
|
||||
# raise RuntimeError(
|
||||
# f"Input points in {solver.problem.not_sampled_points} "
|
||||
# "training are None. Please "
|
||||
# "sample points in your problem by calling "
|
||||
# "discretise_domain function before train "
|
||||
# "in the provided locations."
|
||||
# )
|
||||
|
||||
self._create_or_update_loader()
|
||||
# self._create_or_update_loader()
|
||||
|
||||
def _create_or_update_loader(self):
|
||||
def _move_to_device(self):
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
|
||||
# move parameters to device
|
||||
pb = self._model.problem
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
for key in pb.unknown_parameters:
|
||||
pb.unknown_parameters[key] = torch.nn.Parameter(
|
||||
pb.unknown_parameters[key].data.to(device)
|
||||
)
|
||||
|
||||
def _create_loader(self):
|
||||
"""
|
||||
This method is used here because is resampling is needed
|
||||
during training, there is no need to define to touch the
|
||||
@@ -64,12 +78,6 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
self._loader = SamplePointLoader(
|
||||
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
|
||||
)
|
||||
pb = self._model.problem
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
for key in pb.unknown_parameters:
|
||||
pb.unknown_parameters[key] = torch.nn.Parameter(
|
||||
pb.unknown_parameters[key].data.to(device)
|
||||
)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user