add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points
* unittests
@@ -15,4 +15,7 @@ from .trainer import Trainer
 from .plotter import Plotter
 from .condition import Condition
 from .geometry import Location
 from .geometry import CartesianDomain
+
+from .dataset import SamplePointDataset
+from .dataset import SamplePointLoader
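With this hunk the two new classes are re-exported at package level, so they can be imported straight from `pina`:

```python
from pina import SamplePointDataset, SamplePointLoader
```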
@@ -1,6 +1,7 @@
 '''PINA Callbacks Implementations'''

-from lightning.pytorch.callbacks import Callback
+# from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 from ..utils import check_consistency
@@ -1,6 +1,6 @@
 '''PINA Callbacks Implementations'''

-from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 from ..utils import check_consistency
288 pina/dataset.py
@@ -1,78 +1,240 @@
-from torch.utils.data import Dataset, DataLoader
-import functools
+from torch.utils.data import Dataset
+import torch
+from pina import LabelTensor


-class PinaDataset():
+class SamplePointDataset(Dataset):
+    """
+    This class is used to create a dataset of sample points.
+    """

-    def __init__(self, pinn) -> None:
-        self.pinn = pinn
-
-    @property
-    def dataloader(self):
-        return self._create_dataloader()
-
-    @property
-    def dataset(self):
-        return [self.SampleDataset(key, val)
-                for key, val in self.input_pts.items()]
-
-    def _create_dataloader(self):
-        """Private method for creating dataloader
-
-        :return: dataloader
-        :rtype: torch.utils.data.DataLoader
-        """
-        if self.pinn.batch_size is None:
-            return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
-
-        def custom_collate(batch):
-            # extracting pts labels
-            _, pts = list(batch[0].items())[0]
-            labels = pts.labels
-            # calling default torch collate
-            collate_res = default_collate(batch)
-            # save collate result in dict
-            res = {}
-            for key, val in collate_res.items():
-                val.labels = labels
-                res[key] = val
-
-        def __getitem__(self, index):
-            tensor = self._tensor.select(0, index)
-            return {self._location: tensor}
-
-        def __len__(self):
-            return self._len
+    def __init__(self, problem, device) -> None:
+        """
+        :param dict input_pts: The input points.
+        """
+        super().__init__()
+        pts_list = []
+        self.condition_names = []
+
+        for name, condition in problem.conditions.items():
+            if not hasattr(condition, 'output_points'):
+                pts_list.append(problem.input_pts[name])
+                self.condition_names.append(name)
+
+        self.pts = LabelTensor.vstack(pts_list)
+
+        if self.pts != []:
+            self.condition_indeces = torch.cat([
+                torch.tensor([i]*len(pts_list[i]))
+                for i in range(len(self.condition_names))
+            ], dim=0)
+        else:  # if there are no sample points
+            self.condition_indeces = torch.tensor([])
+            self.pts = torch.tensor([])
+
+        self.pts = self.pts.to(device)
+        self.condition_indeces = self.condition_indeces.to(device)
+
+    def __len__(self):
+        return self.pts.shape[0]
+
+
+class DataPointDataset(Dataset):
+
+    def __init__(self, problem, device) -> None:
+        super().__init__()
+        input_list = []
+        output_list = []
+        self.condition_names = []
+
+        for name, condition in problem.conditions.items():
+            if hasattr(condition, 'output_points'):
+                input_list.append(problem.conditions[name].input_points)
+                output_list.append(problem.conditions[name].output_points)
+                self.condition_names.append(name)
+
+        self.input_pts = LabelTensor.vstack(input_list)
+        self.output_pts = LabelTensor.vstack(output_list)
+
+        if self.input_pts != []:
+            self.condition_indeces = torch.cat([
+                torch.tensor([i]*len(input_list[i]))
+                for i in range(len(self.condition_names))
+            ], dim=0)
+        else:  # if there are no data points
+            self.condition_indeces = torch.tensor([])
+            self.input_pts = torch.tensor([])
+            self.output_pts = torch.tensor([])
+
+        self.input_pts = self.input_pts.to(device)
+        self.output_pts = self.output_pts.to(device)
+        self.condition_indeces = self.condition_indeces.to(device)
+
+    def __len__(self):
+        return self.input_pts.shape[0]


-# TODO: working also for datapoints
-class DummyLoader:
+class SamplePointLoader:
+    """
+    This class is used to create a dataloader to use during the training.
+
+    :var condition_names: The names of the conditions. The order is consistent
+        with the condition indeces in the batches.
+    :vartype condition_names: list[str]
+    """

-    def __init__(self, data, device) -> None:
-
-        # TODO: We need to make a dataset somehow
-        # and the PINADataset needs to have a method
-        # to send points to device
-        # now we simply do it here
-        # send data to device
-        def convert_tensors(pts, device):
-            pts = pts.to(device)
-            pts.requires_grad_(True)
-            pts.retain_grad()
-            return pts
-
-        for location, pts in data.items():
-            if isinstance(pts, (tuple, list)):
-                pts = tuple(map(functools.partial(convert_tensors, device=device), pts))
-            else:
-                pts = pts.to(device)
-                pts = pts.requires_grad_(True)
-                pts.retain_grad()
-
-            data[location] = pts
-
-        # iterator
-        self.data = [data]
-
-    def __iter__(self):
-        return iter(self.data)
+    def __init__(self, sample_dataset, data_dataset, batch_size=None, shuffle=True) -> None:
+        """
+        Constructor.
+
+        :param SamplePointDataset sample_pts: The sample points dataset.
+        :param int batch_size: The batch size. If ``None``, the batch size is
+            set to the number of sample points. Default is ``None``.
+        :param bool shuffle: If ``True``, the sample points are shuffled.
+            Default is ``True``.
+        """
+        if not isinstance(sample_dataset, SamplePointDataset):
+            raise TypeError(f'Expected SamplePointDataset, got {type(sample_dataset)}')
+        if not isinstance(data_dataset, DataPointDataset):
+            raise TypeError(f'Expected DataPointDataset, got {type(data_dataset)}')
+
+        self.n_data_conditions = len(data_dataset.condition_names)
+        self.n_phys_conditions = len(sample_dataset.condition_names)
+        data_dataset.condition_indeces += self.n_phys_conditions
+
+        self._prepare_sample_dataset(sample_dataset, batch_size, shuffle)
+        self._prepare_data_dataset(data_dataset, batch_size, shuffle)
+
+        self.condition_names = (
+            sample_dataset.condition_names + data_dataset.condition_names)
+
+        self.batch_list = []
+        for i in range(len(self.batch_sample_pts)):
+            self.batch_list.append(('sample', i))
+
+        for i in range(len(self.batch_input_pts)):
+            self.batch_list.append(('data', i))
+
+        if shuffle:
+            self.random_idx = torch.randperm(len(self.batch_list))
+        else:
+            self.random_idx = torch.arange(len(self.batch_list))
+
+    def _prepare_data_dataset(self, dataset, batch_size, shuffle):
+        """
+        Prepare the dataset for data points.
+
+        :param DataPointDataset dataset: The dataset.
+        :param int batch_size: The batch size.
+        :param bool shuffle: If ``True``, the sample points are shuffled.
+        """
+        self.sample_dataset = dataset
+
+        if len(dataset) == 0:
+            self.batch_data_conditions = []
+            self.batch_input_pts = []
+            self.batch_output_pts = []
+            return
+
+        if batch_size is None:
+            batch_size = len(dataset)
+        batch_num = len(dataset) // batch_size
+        if len(dataset) % batch_size != 0:
+            batch_num += 1
+
+        output_labels = dataset.output_pts.labels
+        input_labels = dataset.input_pts.labels
+        self.tensor_conditions = dataset.condition_indeces
+
+        if shuffle:
+            idx = torch.randperm(dataset.input_pts.shape[0])
+            self.input_pts = dataset.input_pts[idx]
+            self.output_pts = dataset.output_pts[idx]
+            self.tensor_conditions = dataset.condition_indeces[idx]
+
+        self.batch_input_pts = torch.tensor_split(
+            dataset.input_pts, batch_num)
+        self.batch_output_pts = torch.tensor_split(
+            dataset.output_pts, batch_num)
+
+        for i in range(len(self.batch_input_pts)):
+            self.batch_input_pts[i].labels = input_labels
+            self.batch_output_pts[i].labels = output_labels
+
+        self.batch_data_conditions = torch.tensor_split(
+            self.tensor_conditions, batch_num)
+
+    def _prepare_sample_dataset(self, dataset, batch_size, shuffle):
+        """
+        Prepare the dataset for sample points.
+
+        :param SamplePointDataset dataset: The dataset.
+        :param int batch_size: The batch size.
+        :param bool shuffle: If ``True``, the sample points are shuffled.
+        """
+        self.sample_dataset = dataset
+        if len(dataset) == 0:
+            self.batch_sample_conditions = []
+            self.batch_sample_pts = []
+            return
+
+        if batch_size is None:
+            batch_size = len(dataset)
+
+        batch_num = len(dataset) // batch_size
+        if len(dataset) % batch_size != 0:
+            batch_num += 1
+
+        self.tensor_pts = dataset.pts
+        self.tensor_conditions = dataset.condition_indeces
+
+        # if shuffle:
+        #     idx = torch.randperm(self.tensor_pts.shape[0])
+        #     self.tensor_pts = self.tensor_pts[idx]
+        #     self.tensor_conditions = self.tensor_conditions[idx]
+
+        self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num)
+        for i in range(len(self.batch_sample_pts)):
+            self.batch_sample_pts[i].labels = dataset.pts.labels
+
+        self.batch_sample_conditions = torch.tensor_split(
+            self.tensor_conditions, batch_num)
+
+    def __iter__(self):
+        """
+        Return an iterator over the points. Any element of the iterator is a
+        dictionary with the following keys:
+            - ``pts``: The input sample points. It is a LabelTensor with the
+                shape ``(batch_size, input_dimension)``.
+            - ``output``: The output sample points. This key is present only
+                if data conditions are present. It is a LabelTensor with the
+                shape ``(batch_size, output_dimension)``.
+            - ``condition``: The integer condition indeces. It is a tensor
+                with the shape ``(batch_size, )`` of type ``torch.int64`` and
+                indicates for any ``pts`` the corresponding problem condition.
+
+        :return: An iterator over the points.
+        :rtype: iter
+        """
+        # for i in self.random_idx:
+        for i in range(len(self.batch_list)):
+            type_, idx_ = self.batch_list[i]
+
+            if type_ == 'sample':
+                d = {
+                    'pts': self.batch_sample_pts[idx_].requires_grad_(True),
+                    'condition': self.batch_sample_conditions[idx_],
+                }
+            else:
+                d = {
+                    'pts': self.batch_input_pts[idx_].requires_grad_(True),
+                    'output': self.batch_output_pts[idx_],
+                    'condition': self.batch_data_conditions[idx_],
+                }
+            yield d
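For orientation, a minimal sketch of how the new dataset/loader pair is meant to be driven, mirroring tests/test_dataset.py further down; `problem` is an assumed, already-discretised PINA problem:

```python
from pina.dataset import SamplePointDataset, DataPointDataset, SamplePointLoader

# `problem` is assumed to be a PINA problem whose location-based
# conditions were already discretised via discretise_domain.
sample_dataset = SamplePointDataset(problem, device='cpu')  # physics conditions
data_dataset = DataPointDataset(problem, device='cpu')      # supervised conditions
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)

for batch in loader:
    pts = batch['pts']           # LabelTensor, shape (batch_size, input_dim)
    cond = batch['condition']    # int64 tensor mapping each row to a condition
    if 'output' in batch:        # present only for data batches
        out = batch['output']    # LabelTensor, shape (batch_size, output_dim)
```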
@@ -55,13 +55,15 @@ class SimplexDomain(Location):
             raise ValueError("An n-dimensional simplex is composed by n + 1 tensors of dimension n.")

         # creating vertices matrix
-        self._vertices_matrix = torch.cat(simplex_matrix)
-        self._vertices_matrix.labels = matrix_labels
+        self._vertices_matrix = LabelTensor.vstack(simplex_matrix)

         # creating basis vectors for simplex
-        self._vectors_shifted = (
-            (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
-        )
+        # self._vectors_shifted = (
+        #     (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
+        # ) ### TODO: Remove after checking
+
+        vert = self._vertices_matrix
+        self._vectors_shifted = (vert[:-1] - vert[-1]).T

         # build cartesian_bound
         self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
@@ -114,8 +116,8 @@ class SimplexDomain(Location):
                 f" expected {self.variables}."
             )

         # shift point
-        point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
+        point_shift = point - self._vertices_matrix[-1]
+        point_shift = point_shift.tensor.reshape(-1, 1)

         # compute barycentric coordinates
         lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)
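The rewritten shift makes the barycentric test easier to follow. A self-contained sketch of the same computation on a plain triangle (plain torch, no LabelTensor): a point lies inside the simplex iff its barycentric coordinates are all non-negative and sum to at most 1 (the last coordinate is 1 minus the sum).

```python
import torch

vertices = torch.tensor([[0., 0.], [1., 0.], [0., 1.]])  # 2D simplex (triangle)
point = torch.tensor([0.25, 0.25])

vectors_shifted = (vertices[:-1] - vertices[-1]).T          # as in the diff
point_shift = (point - vertices[-1]).reshape(-1, 1)
lambda_ = torch.linalg.solve(vectors_shifted, point_shift)  # barycentric coords

inside = bool((lambda_ >= 0).all() and lambda_.sum() <= 1)
print(inside)  # True: (0.25, 0.25) = 0.5*v0 + 0.25*v1 + 0.25*v2
```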
@@ -96,6 +96,28 @@ class LabelTensor(torch.Tensor):

         self._labels = labels # assign the label

+    @staticmethod
+    def vstack(label_tensors):
+        """
+        Stack tensors vertically. For more details, see
+        :meth:`torch.vstack`.
+
+        :param list(LabelTensor) label_tensors: the tensors to stack. They need
+            to have equal labels.
+        :return: the stacked tensor
+        :rtype: LabelTensor
+        """
+        if len(label_tensors) == 0:
+            return []
+
+        all_labels = [label for lt in label_tensors for label in lt.labels]
+        if set(all_labels) != set(label_tensors[0].labels):
+            raise RuntimeError('The tensors to stack have different labels')
+
+        labels = label_tensors[0].labels
+        tensors = [lt.extract(labels) for lt in label_tensors]
+        return LabelTensor(torch.vstack(tensors), labels)
+
     # TODO remove try/ except thing IMPORTANT
     # make the label None of default
     def clone(self, *args, **kwargs):
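A quick illustration of the new helper (labels must coincide across the stacked tensors, otherwise a RuntimeError is raised):

```python
import torch
from pina import LabelTensor

a = LabelTensor(torch.rand(3, 2), ['x', 'y'])
b = LabelTensor(torch.rand(5, 2), ['x', 'y'])

stacked = LabelTensor.vstack([a, b])
assert stacked.shape == (8, 2) and stacked.labels == ['x', 'y']
```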
@@ -183,6 +205,18 @@ class LabelTensor(torch.Tensor):

         return extracted_tensor

+    def detach(self):
+        detached = super().detach()
+        if hasattr(self, '_labels'):
+            detached._labels = self._labels
+        return detached
+
+    def requires_grad_(self, mode = True) -> Tensor:
+        lt = super().requires_grad_(mode)
+        lt.labels = self.labels
+        return lt
+
     def append(self, lt, mode='std'):
         """
         Return a copy of the merged tensors.
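These overrides exist because the parent torch.Tensor methods return tensors without the `labels` attribute; a small check of the intended behaviour:

```python
import torch
from pina import LabelTensor

t = LabelTensor(torch.rand(4, 2), ['x', 'y'])

g = t.requires_grad_(True)            # labels survive the in-place toggle
assert g.labels == ['x', 'y'] and g.requires_grad

d = g.detach()                        # labels survive detaching from the graph
assert d.labels == ['x', 'y'] and not d.requires_grad
```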
@@ -232,7 +266,7 @@ class LabelTensor(torch.Tensor):
             len_index = len(index)
         except TypeError:
             len_index = 1

         if isinstance(index, int) or len_index == 1:
             if selected_lt.ndim == 1:
                 selected_lt = selected_lt.reshape(1, -1)
@@ -246,8 +280,14 @@ class LabelTensor(torch.Tensor):
                 selected_lt.labels = [self.labels[i] for i in index[1]]
             else:
                 selected_lt.labels = self.labels[index[1]]
+        else:
+            selected_lt.labels = self.labels

         return selected_lt

+    @property
+    def tensor(self):
+        return self.as_subclass(Tensor)
+
     def __len__(self) -> int:
         return super().__len__()
@@ -111,8 +111,9 @@ class LpLoss(LossInterface):

         # check consistency
         check_consistency(p, (str, int, float))
-        self.p = p
         check_consistency(relative, bool)
+
+        self.p = p
         self.relative = relative

     def forward(self, input, target):
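For context, a standalone sketch of what an Lp-type loss with a `relative` option computes. This is an assumed reading of LpLoss's semantics (the diff only reorders the constructor's consistency checks); the real implementation lives in pina and may differ in reduction details:

```python
import torch

def lp_loss(input, target, p=2, relative=False):
    # mean of p-th power residuals, optionally scaled by the target norm
    loss = torch.abs(input - target).pow(p).mean()
    if relative:
        loss = loss / torch.abs(target).pow(p).mean()
    return loss

print(lp_loss(torch.ones(4), torch.zeros(4)))  # tensor(1.)
```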
@@ -22,6 +22,8 @@ class AbstractProblem(metaclass=ABCMeta):
         # variable to check if sampling is done. If no location
         # element is present in Condition this variable is set to true
         self._have_sampled_points = {}
         for condition_name in self.conditions:
             self._have_sampled_points[condition_name] = False

+        # put in self.input_pts all the points that we don't need to sample
+        self._span_condition_points()

@@ -102,15 +104,10 @@ class AbstractProblem(metaclass=ABCMeta):
         """
         for condition_name in self.conditions:
             condition = self.conditions[condition_name]
-            if hasattr(condition, 'equation') and hasattr(condition, 'input_points'):
+            if hasattr(condition, 'input_points'):
                 samples = condition.input_points
-            elif hasattr(condition, 'output_points') and hasattr(condition, 'input_points'):
-                samples = (condition.input_points, condition.output_points)
+            # skip if we need to sample
             elif hasattr(condition, 'location'):
-                self._have_sampled_points[condition_name] = False
                 continue
             self.input_pts[condition_name] = samples
+            self._have_sampled_points[condition_name] = True

     def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
         """
@@ -204,7 +201,7 @@ class AbstractProblem(metaclass=ABCMeta):

     def add_points(self, new_points):
         """
-        Adding points to the already sampled points
+        Adding points to the already sampled points.

         :param dict new_points: a dictionary with key the location to add the points
             and values the torch.Tensor points.
@@ -115,12 +115,15 @@ class GAROM(SolverInterface):
         check_consistency(lambda_k, float)
         check_consistency(regularizer, bool)

         # assign schedulers
-        self._schedulers = [scheduler_generator(self.optimizers[0],
-                                                **scheduler_generator_kwargs),
-                            scheduler_discriminator(self.optimizers[1],
-                                                    **scheduler_discriminator_kwargs)]
+        self._schedulers = [
+            scheduler_generator(
+                self.optimizers[0], **scheduler_generator_kwargs),
+            scheduler_discriminator(
+                self.optimizers[1],
+                **scheduler_discriminator_kwargs)
+        ]

         # loss and writer
         self._loss = loss
@@ -157,6 +160,63 @@ class GAROM(SolverInterface):
     def sample(self, x):
         # sampling
         return self.generator(x)

+    def _train_generator(self, parameters, snapshots):
+        """
+        Private method to train the generator network.
+        """
+        optimizer = self.optimizer_generator
+
+        generated_snapshots = self.generator(parameters)
+
+        # generator loss
+        r_loss = self._loss(snapshots, generated_snapshots)
+        d_fake = self.discriminator([generated_snapshots, parameters])
+        g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
+
+        # backward step
+        g_loss.backward()
+        optimizer.step()
+
+        return r_loss, g_loss
+
+    def _train_discriminator(self, parameters, snapshots):
+        """
+        Private method to train the discriminator network.
+        """
+        optimizer = self.optimizer_discriminator
+        optimizer.zero_grad()
+
+        # Generate a batch of images
+        generated_snapshots = self.generator(parameters)
+
+        # Discriminator pass
+        d_real = self.discriminator([snapshots, parameters])
+        d_fake = self.discriminator([generated_snapshots, parameters])
+
+        # evaluate loss
+        d_loss_real = self._loss(d_real, snapshots)
+        d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
+        d_loss = d_loss_real - self.k * d_loss_fake
+
+        # backward step
+        d_loss.backward(retain_graph=True)
+        optimizer.step()
+
+        return d_loss_real, d_loss_fake, d_loss
+
+    def _update_weights(self, d_loss_real, d_loss_fake):
+        """
+        Private method to update the weights of the generator and
+        discriminator networks.
+        """
+        diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
+
+        # Update weight term for fake samples
+        self.k += self.lambda_k * diff.item()
+        self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]
+        return diff
+
     def training_step(self, batch, batch_idx):
         """PINN solver training step.
@@ -169,78 +229,40 @@
         :rtype: LabelTensor
         """

-        for condition_name, samples in batch.items():
-
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError('Something wrong happened.')
-
-            condition = self.problem.conditions[condition_name]
-
-            # for data driven mode
-            if hasattr(condition, 'output_points'):
-
-                # get data
-                parameters, input_pts = samples
-
-                # get optimizers
-                opt_gen, opt_disc = self.optimizers
-
-                # ---------------------
-                #  Train Discriminator
-                # ---------------------
-                opt_disc.zero_grad()
-
-                # Generate a batch of images
-                gen_imgs = self.generator(parameters)
-
-                # Discriminator pass
-                d_real = self.discriminator([input_pts, parameters])
-                d_fake = self.discriminator([gen_imgs.detach(), parameters])
-
-                # evaluate loss
-                d_loss_real = self._loss(d_real, input_pts)
-                d_loss_fake = self._loss(d_fake, gen_imgs.detach())
-                d_loss = d_loss_real - self.k * d_loss_fake
-
-                # backward step
-                d_loss.backward()
-                opt_disc.step()
-
-                # -----------------
-                #  Train Generator
-                # -----------------
-                opt_gen.zero_grad()
-
-                # Generate a batch of images
-                gen_imgs = self.generator(parameters)
-
-                # generator loss
-                r_loss = self._loss(input_pts, gen_imgs)
-                d_fake = self.discriminator([gen_imgs, parameters])
-                g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss
-
-                # backward step
-                g_loss.backward()
-                opt_gen.step()
-
-                # ----------------
-                # Update weights
-                # ----------------
-                diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
-
-                # Update weight term for fake samples
-                self.k += self.lambda_k * diff.item()
-                self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]
-
-                # logging
-                self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
-                self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
-                self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
-                self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
-
-            else:
-                raise NotImplementedError('GAROM works only in data-driven mode.')
+        dataloader = self.trainer.train_dataloader
+        condition_idx = batch['condition']
+
+        for condition_id in range(condition_idx.min(), condition_idx.max()+1):
+
+            condition_name = dataloader.condition_names[condition_id]
+            condition = self.problem.conditions[condition_name]
+            pts = batch['pts']
+            out = batch['output']
+
+            if not hasattr(condition, 'output_points'):
+                raise NotImplementedError('GAROM works only in data-driven mode.')
+
+            # get data
+            snapshots = out[condition_idx == condition_id]
+            parameters = pts[condition_idx == condition_id]
+
+            d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
+                parameters, snapshots)
+            r_loss, g_loss = self._train_generator(parameters, snapshots)
+            diff = self._update_weights(d_loss_real, d_loss_fake)
+
+            # logging
+            self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
+            self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
+            self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
+            self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)

         return

     @property
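The `_update_weights` helper extracted above follows a BEGAN-style equilibrium rule: `k` is nudged proportionally to the gap `gamma * d_loss_real - d_loss_fake` and clamped to [0, 1]. A standalone numeric illustration of that control step:

```python
k, lambda_k, gamma = 0.0, 0.001, 0.3
d_loss_real, d_loss_fake = 0.9, 0.2

diff = gamma * d_loss_real - d_loss_fake  # ~0.07: discriminator too strong on real
k += lambda_k * diff                      # proportional control step
k = min(max(k, 0), 1)                     # constrain to interval [0, 1]
print(k)                                  # ~7e-05
```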
@@ -97,6 +97,15 @@ class PINN(SolverInterface):
         """
         return self.optimizers, [self.scheduler]

+    def _loss_data(self, input, output):
+        return self.loss(self.forward(input), output)
+
+    def _loss_phys(self, samples, equation):
+        residual = equation.residual(samples, self.forward(samples))
+        return self.loss(torch.zeros_like(residual, requires_grad=True), residual)
+
     def training_step(self, batch, batch_idx):
         """PINN solver training step.
@@ -108,25 +117,29 @@
         :rtype: LabelTensor
         """

+        dataloader = self.trainer.train_dataloader
         condition_losses = []
-        condition_names = []
+        condition_idx = batch['condition']

-        for condition_name, samples in batch.items():
+        for condition_id in range(condition_idx.min(), condition_idx.max()+1):

-            if condition_name not in self.problem.conditions:
-                raise RuntimeError('Something wrong happened.')
-
-            condition_names.append(condition_name)
+            condition_name = dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
+            pts = batch['pts']

-            # PINN loss: equation evaluated on location or input_points
-            if hasattr(condition, 'equation'):
-                target = condition.equation.residual(samples, self.forward(samples))
-                loss = self.loss(torch.zeros_like(target), target)
-            # PINN loss: evaluate model(input_points) vs output_points
-            elif hasattr(condition, 'output_points'):
-                input_pts, output_pts = samples
-                loss = self.loss(self.forward(input_pts), output_pts)
+            if len(batch) == 2:
+                samples = pts[condition_idx == condition_id]
+                loss = self._loss_phys(pts, condition.equation)
+            elif len(batch) == 3:
+                samples = pts[condition_idx == condition_id]
+                ground_truth = batch['output'][condition_idx == condition_id]
+                loss = self._loss_data(samples, ground_truth)
+            else:
+                raise ValueError("Batch size not supported")

             loss = loss.as_subclass(torch.Tensor)

             condition_losses.append(loss * condition.data_weight)
@@ -135,8 +148,8 @@
         total_loss = sum(condition_losses)

         self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
-        for condition_loss, loss in zip(condition_names, condition_losses):
-            self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
+        # for condition_loss, loss in zip(condition_names, condition_losses):
+        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
         return total_loss

     @property
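The new training step splits one mixed batch per condition with boolean masks over the int64 condition indices; a minimal sketch of that pattern:

```python
import torch

pts = torch.rand(6, 2)                         # one batch of mixed points
condition_idx = torch.tensor([0, 0, 1, 1, 1, 0])

for condition_id in range(int(condition_idx.min()), int(condition_idx.max()) + 1):
    samples = pts[condition_idx == condition_id]  # rows of this condition only
    print(condition_id, samples.shape[0])         # 0 -> 3 rows, 1 -> 3 rows
```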
@@ -2,7 +2,7 @@

 from abc import ABCMeta, abstractmethod
 from ..model.network import Network
-import lightning.pytorch as pl
+import pytorch_lightning as pl
 from ..utils import check_consistency
 from ..problem import AbstractProblem
 import torch
@@ -1,18 +1,19 @@
 """ Solver module. """

-import lightning.pytorch as pl
+from pytorch_lightning import Trainer
 from .utils import check_consistency
-from .dataset import DummyLoader
+from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
 from .solvers.solver import SolverInterface

-class Trainer(pl.Trainer):
+class Trainer(Trainer):

-    def __init__(self, solver, **kwargs):
+    def __init__(self, solver, batch_size=None, **kwargs):
         super().__init__(**kwargs)

         # check inheritance consistency for solver
         check_consistency(solver, SolverInterface)
         self._model = solver
+        self.batch_size = batch_size

         # create dataloader
         if solver.problem.have_sampled_points is False:
@@ -22,19 +23,31 @@ class Trainer(pl.Trainer):
                 'discretise_domain function before train '
                 'in the provided locations.')

-        # TODO: make a better dataloader for train
         self._create_or_update_loader()

-    # this method is used here because, if resampling is needed
-    # during training, there is no need to touch the
-    # trainer dataloader; just call the method.
     def _create_or_update_loader(self):
-        # get accelerator
-        device = self._accelerator_connector._accelerator_flag
-        self._loader = DummyLoader(self._model.problem.input_pts, device)
+        """
+        This method is used here because, if resampling is needed during
+        training, there is no need to touch the trainer dataloader; just
+        call this method.
+        """
+        devices = self._accelerator_connector._parallel_devices

-    def train(self, **kwargs):  # TODO add kwargs and lightning capabilities
-        return super().fit(self._model, self._loader, **kwargs)
+        if len(devices) > 1:
+            raise RuntimeError('Parallel training is not supported yet.')
+
+        device = devices[0]
+        dataset_phys = SamplePointDataset(self._model.problem, device)
+        dataset_data = DataPointDataset(self._model.problem, device)
+        self._loader = SamplePointLoader(
+            dataset_phys, dataset_data, batch_size=self.batch_size,
+            shuffle=True)
+
+    def train(self, **kwargs):
+        """
+        Train the solver.
+        """
+        return super().fit(self._model, train_dataloaders=self._loader, **kwargs)

     @property
     def solver(self):
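Putting the Trainer changes together, mini-batched training now looks like this (mirrors the updated tests; `pinn` stands for any already-constructed solver over a discretised problem):

```python
from pina import Trainer

# `pinn` is assumed to be an already-constructed solver (e.g. PINN).
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu', batch_size=20)
trainer.train()
```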
2 setup.py
@@ -15,7 +15,7 @@ VERSION = meta['__version__']
 KEYWORDS = 'physics-informed neural-network'

 REQUIRED = [
-    'numpy', 'matplotlib', 'torch', 'lightning'
+    'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning'
 ]

 EXTRAS = {
@@ -44,9 +44,9 @@ class Poisson(SpatialProblem):
         'D': Condition(
             input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
             equation=my_laplace),
-        'data': Condition(
-            input_points=in_,
-            output_points=out_)
+        # 'data': Condition(
+        #     input_points=in_,
+        #     output_points=out_)
     }

@@ -44,9 +44,9 @@ class Poisson(SpatialProblem):
         'D': Condition(
             input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
             equation=my_laplace),
-        'data': Condition(
-            input_points=in_,
-            output_points=out_)
+        # 'data': Condition(
+        #     input_points=in_,
+        #     output_points=out_)
     }
122 tests/test_dataset.py (new file)
@@ -0,0 +1,122 @@
+import torch
+import pytest
+
+from pina.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
+from pina import LabelTensor, Condition
+from pina.equation import Equation
+from pina.geometry import CartesianDomain
+from pina.problem import SpatialProblem
+from pina.model import FeedForward
+from pina.operators import laplacian
+from pina.equation.equation_factory import FixedValue
+
+
+def laplace_equation(input_, output_):
+    force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
+                  torch.sin(input_.extract(['y'])*torch.pi))
+    delta_u = laplacian(output_.extract(['u']), input_)
+    return delta_u - force_term
+
+my_laplace = Equation(laplace_equation)
+in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
+out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
+out2_ = LabelTensor(torch.rand(60, 1), ['u'])
+
+class Poisson(SpatialProblem):
+    output_variables = ['u']
+    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
+
+    conditions = {
+        'gamma1': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 1}),
+            equation=FixedValue(0.0)),
+        'gamma2': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 0}),
+            equation=FixedValue(0.0)),
+        'gamma3': Condition(
+            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'gamma4': Condition(
+            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'D': Condition(
+            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
+            equation=my_laplace),
+        'data': Condition(
+            input_points=in_,
+            output_points=out_),
+        'data2': Condition(
+            input_points=in2_,
+            output_points=out2_)
+    }
+
+boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+poisson = Poisson()
+poisson.discretise_domain(10, 'grid', locations=boundaries)
+
+def test_sample():
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    assert len(sample_dataset) == 140
+    assert sample_dataset.pts.shape == (140, 2)
+    assert sample_dataset.pts.labels == ['x', 'y']
+    assert sample_dataset.condition_indeces.dtype == torch.int64
+    assert sample_dataset.condition_indeces.max() == torch.tensor(4)
+    assert sample_dataset.condition_indeces.min() == torch.tensor(0)
+
+def test_data():
+    dataset = DataPointDataset(poisson, device='cpu')
+    assert len(dataset) == 61
+    assert dataset.input_pts.shape == (61, 2)
+    assert dataset.input_pts.labels == ['x', 'y']
+    assert dataset.output_pts.shape == (61, 1)
+    assert dataset.output_pts.labels == ['u']
+    assert dataset.condition_indeces.dtype == torch.int64
+    assert dataset.condition_indeces.max() == torch.tensor(1)
+    assert dataset.condition_indeces.min() == torch.tensor(0)
+
+def test_loader():
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    data_dataset = DataPointDataset(poisson, device='cpu')
+    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+
+    for batch in loader:
+        assert len(batch) in [2, 3]
+        assert batch['pts'].shape[0] <= 10
+        assert batch['pts'].requires_grad == True
+        assert batch['pts'].labels == ['x', 'y']
+
+    loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
+    assert len(list(loader2)) == 2
+
+def test_loader2():
+    poisson2 = Poisson()
+    del poisson.conditions['data2']
+    del poisson2.conditions['data']
+    poisson2.discretise_domain(10, 'grid', locations=boundaries)
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    data_dataset = DataPointDataset(poisson, device='cpu')
+    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+
+    for batch in loader:
+        assert len(batch) == 2  # only phys conditions
+        assert batch['pts'].shape[0] <= 10
+        assert batch['pts'].requires_grad == True
+        assert batch['pts'].labels == ['x', 'y']
+
+def test_loader3():
+    poisson2 = Poisson()
+    del poisson.conditions['gamma1']
+    del poisson.conditions['gamma2']
+    del poisson.conditions['gamma3']
+    del poisson.conditions['gamma4']
+    del poisson.conditions['D']
+    sample_dataset = SamplePointDataset(poisson, device='cpu')
+    data_dataset = DataPointDataset(poisson, device='cpu')
+    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+
+    for batch in loader:
+        assert len(batch) == 2  # only phys conditions
+        assert batch['pts'].shape[0] <= 10
+        assert batch['pts'].requires_grad == True
+        assert batch['pts'].labels == ['x', 'y']
@@ -95,10 +95,14 @@ def test_getitem():
 def test_getitem2():
     tensor = LabelTensor(data, labels)
     tensor_view = tensor[:5]

     assert tensor_view.labels == labels
     assert torch.allclose(tensor_view, data[:5])

+    idx = torch.randperm(tensor.shape[0])
+    tensor_view = tensor[idx]
+    assert tensor_view.labels == labels
+
 def test_slice():
     tensor = LabelTensor(data, labels)
     tensor_view = tensor[:5, :2]
@@ -134,7 +134,7 @@ def test_train_cpu():
                           hidden_dimension=64)
     )

-    trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu')
+    trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu', batch_size=20)
     trainer.train()

 def test_sample():
@@ -22,6 +22,8 @@ def laplace_equation(input_, output_):
 my_laplace = Equation(laplace_equation)
 in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
 out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
+out2_ = LabelTensor(torch.rand(60, 1), ['u'])

 class Poisson(SpatialProblem):
     output_variables = ['u']
@@ -45,7 +47,10 @@ class Poisson(SpatialProblem):
             equation=my_laplace),
         'data': Condition(
             input_points=in_,
-            output_points=out_)
+            output_points=out_),
+        'data2': Condition(
+            input_points=in2_,
+            output_points=out2_)
     }

     def poisson_sol(self, pts):
@@ -92,7 +97,7 @@ def test_train_cpu():
     n = 10
     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
-    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
+    trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20)
     trainer.train()

 def test_train_restore():
@@ -106,7 +111,7 @@ def test_train_restore():
     trainer.train()
     ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
     t = ntrainer.train(
-        ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
+        ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
     import shutil
     shutil.rmtree(tmpdir)
@@ -121,7 +126,7 @@ def test_train_load():
         default_root_dir=tmpdir)
     trainer.train()
     new_pinn = PINN.load_from_checkpoint(
-        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
+        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
         problem = poisson_problem, model=model)
     test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
     assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)