add dataset and dataloader for sample points (#195)

* add dataset and dataloader for sample points
* unittests
This commit is contained in:
Nicola Demo
2023-11-07 11:34:44 +01:00
parent cd5bc9a558
commit d654259428
19 changed files with 581 additions and 196 deletions

View File

@@ -15,4 +15,7 @@ from .trainer import Trainer
from .plotter import Plotter
from .condition import Condition
from .geometry import Location
from .geometry import CartesianDomain
from .geometry import CartesianDomain
from .dataset import SamplePointDataset
from .dataset import SamplePointLoader

View File

@@ -1,6 +1,7 @@
'''PINA Callbacks Implementations'''
from lightning.pytorch.callbacks import Callback
# from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch
from ..utils import check_consistency

View File

@@ -1,6 +1,6 @@
'''PINA Callbacks Implementations'''
from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch
from ..utils import check_consistency

View File

@@ -1,78 +1,240 @@
from torch.utils.data import Dataset, DataLoader
import functools
from torch.utils.data import Dataset
import torch
from pina import LabelTensor
class PinaDataset():
class SamplePointDataset(Dataset):
"""
This class is used to create a dataset of sample points.
"""
def __init__(self, pinn) -> None:
self.pinn = pinn
@property
def dataloader(self):
return self._create_dataloader()
@property
def dataset(self):
return [self.SampleDataset(key, val)
for key, val in self.input_pts.items()]
def _create_dataloader(self):
"""Private method for creating dataloader
:return: dataloader
:rtype: torch.utils.data.DataLoader
def __init__(self, problem, device) -> None:
"""
if self.pinn.batch_size is None:
return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
:param dict input_pts: The input points.
"""
super().__init__()
pts_list = []
self.condition_names = []
def custom_collate(batch):
# extracting pts labels
_, pts = list(batch[0].items())[0]
labels = pts.labels
# calling default torch collate
collate_res = default_collate(batch)
# save collate result in dict
res = {}
for key, val in collate_res.items():
val.labels = labels
res[key] = val
def __getitem__(self, index):
tensor = self._tensor.select(0, index)
return {self._location: tensor}
for name, condition in problem.conditions.items():
if not hasattr(condition, 'output_points'):
pts_list.append(problem.input_pts[name])
self.condition_names.append(name)
def __len__(self):
return self._len
self.pts = LabelTensor.vstack(pts_list)
if self.pts != []:
self.condition_indeces = torch.cat([
torch.tensor([i]*len(pts_list[i]))
for i in range(len(self.condition_names))
], dim=0)
else: # if there are no sample points
self.condition_indeces = torch.tensor([])
self.pts = torch.tensor([])
self.pts = self.pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.pts.shape[0]
class DataPointDataset(Dataset):
def __init__(self, problem, device) -> None:
super().__init__()
input_list = []
output_list = []
self.condition_names = []
for name, condition in problem.conditions.items():
if hasattr(condition, 'output_points'):
input_list.append(problem.conditions[name].input_points)
output_list.append(problem.conditions[name].output_points)
self.condition_names.append(name)
self.input_pts = LabelTensor.vstack(input_list)
self.output_pts = LabelTensor.vstack(output_list)
if self.input_pts != []:
self.condition_indeces = torch.cat([
torch.tensor([i]*len(input_list[i]))
for i in range(len(self.condition_names))
], dim=0)
else: # if there are no data points
self.condition_indeces = torch.tensor([])
self.input_pts = torch.tensor([])
self.output_pts = torch.tensor([])
self.input_pts = self.input_pts.to(device)
self.output_pts = self.output_pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.input_pts.shape[0]
class SamplePointLoader:
"""
This class is used to create a dataloader to use during the training.
# TODO: working also for datapoints
class DummyLoader:
:var condition_names: The names of the conditions. The order is consistent
with the condition indeces in the batches.
:vartype condition_names: list[str]
"""
def __init__(self, data, device) -> None:
def __init__(self, sample_dataset, data_dataset, batch_size=None, shuffle=True) -> None:
"""
Constructor.
# TODO: We need to make a dataset somehow
# and the PINADataset needs to have a method
# to send points to device
# now we simply do it here
# send data to device
def convert_tensors(pts, device):
pts = pts.to(device)
pts.requires_grad_(True)
pts.retain_grad()
return pts
:param SamplePointDataset sample_pts: The sample points dataset.
:param int batch_size: The batch size. If ``None``, the batch size is
set to the number of sample points. Default is ``None``.
:param bool shuffle: If ``True``, the sample points are shuffled.
Default is ``True``.
"""
if not isinstance(sample_dataset, SamplePointDataset):
raise TypeError(f'Expected SamplePointDataset, got {type(sample_dataset)}')
if not isinstance(data_dataset, DataPointDataset):
raise TypeError(f'Expected DataPointDataset, got {type(data_dataset)}')
for location, pts in data.items():
if isinstance(pts, (tuple, list)):
pts = tuple(map(functools.partial(convert_tensors, device=device),pts))
else:
pts = pts.to(device)
pts = pts.requires_grad_(True)
pts.retain_grad()
data[location] = pts
self.n_data_conditions = len(data_dataset.condition_names)
self.n_phys_conditions = len(sample_dataset.condition_names)
data_dataset.condition_indeces += self.n_phys_conditions
# iterator
self.data = [data]
self._prepare_sample_dataset(sample_dataset, batch_size, shuffle)
self._prepare_data_dataset(data_dataset, batch_size, shuffle)
self.condition_names = (
sample_dataset.condition_names + data_dataset.condition_names)
self.batch_list = []
for i in range(len(self.batch_sample_pts)):
self.batch_list.append(
('sample', i)
)
for i in range(len(self.batch_input_pts)):
self.batch_list.append(
('data', i)
)
if shuffle:
self.random_idx = torch.randperm(len(self.batch_list))
else:
self.random_idx = torch.arange(len(self.batch_list))
def _prepare_data_dataset(self, dataset, batch_size, shuffle):
"""
Prepare the dataset for data points.
:param SamplePointDataset dataset: The dataset.
:param int batch_size: The batch size.
:param bool shuffle: If ``True``, the sample points are shuffled.
"""
self.sample_dataset = dataset
if len(dataset) == 0:
self.batch_data_conditions = []
self.batch_input_pts = []
self.batch_output_pts = []
return
if batch_size is None:
batch_size = len(dataset)
batch_num = len(dataset) // batch_size
if len(dataset) % batch_size != 0:
batch_num += 1
output_labels = dataset.output_pts.labels
input_labels = dataset.input_pts.labels
self.tensor_conditions = dataset.condition_indeces
if shuffle:
idx = torch.randperm(dataset.input_pts.shape[0])
self.input_pts = dataset.input_pts[idx]
self.output_pts = dataset.output_pts[idx]
self.tensor_conditions = dataset.condition_indeces[idx]
self.batch_input_pts = torch.tensor_split(
dataset.input_pts, batch_num)
self.batch_output_pts = torch.tensor_split(
dataset.output_pts, batch_num)
for i in range(len(self.batch_input_pts)):
self.batch_input_pts[i].labels = input_labels
self.batch_output_pts[i].labels = output_labels
self.batch_data_conditions = torch.tensor_split(
self.tensor_conditions, batch_num)
def _prepare_sample_dataset(self, dataset, batch_size, shuffle):
"""
Prepare the dataset for sample points.
:param DataPointDataset dataset: The dataset.
:param int batch_size: The batch size.
:param bool shuffle: If ``True``, the sample points are shuffled.
"""
self.sample_dataset = dataset
if len(dataset) == 0:
self.batch_sample_conditions = []
self.batch_sample_pts = []
return
if batch_size is None:
batch_size = len(dataset)
batch_num = len(dataset) // batch_size
if len(dataset) % batch_size != 0:
batch_num += 1
self.tensor_pts = dataset.pts
self.tensor_conditions = dataset.condition_indeces
# if shuffle:
# idx = torch.randperm(self.tensor_pts.shape[0])
# self.tensor_pts = self.tensor_pts[idx]
# self.tensor_conditions = self.tensor_conditions[idx]
self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num)
for i in range(len(self.batch_sample_pts)):
self.batch_sample_pts[i].labels = dataset.pts.labels
self.batch_sample_conditions = torch.tensor_split(
self.tensor_conditions, batch_num)
def __iter__(self):
return iter(self.data)
"""
Return an iterator over the points. Any element of the iterator is a
dictionary with the following keys:
- ``pts``: The input sample points. It is a LabelTensor with the
shape ``(batch_size, input_dimension)``.
- ``output``: The output sample points. This key is present only
if data conditions are present. It is a LabelTensor with the
shape ``(batch_size, output_dimension)``.
- ``condition``: The integer condition indeces. It is a tensor
with the shape ``(batch_size, )`` of type ``torch.int64`` and
indicates for any ``pts`` the corresponding problem condition.
:return: An iterator over the points.
:rtype: iter
"""
#for i in self.random_idx:
for i in range(len(self.batch_list)):
type_, idx_ = self.batch_list[i]
if type_ == 'sample':
d = {
'pts': self.batch_sample_pts[idx_].requires_grad_(True),
'condition': self.batch_sample_conditions[idx_],
}
else:
d = {
'pts': self.batch_input_pts[idx_].requires_grad_(True),
'output': self.batch_output_pts[idx_],
'condition': self.batch_data_conditions[idx_],
}
yield d

View File

@@ -55,13 +55,15 @@ class SimplexDomain(Location):
raise ValueError("An n-dimensional simplex is composed by n + 1 tensors of dimension n.")
# creating vertices matrix
self._vertices_matrix = torch.cat(simplex_matrix)
self._vertices_matrix.labels = matrix_labels
self._vertices_matrix = LabelTensor.vstack(simplex_matrix)
# creating basis vectors for simplex
self._vectors_shifted = (
(self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
)
# self._vectors_shifted = (
# (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
# ) ### TODO: Remove after checking
vert = self._vertices_matrix
self._vectors_shifted = (vert[:-1] - vert[-1]).T
# build cartesian_bound
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
@@ -114,8 +116,8 @@ class SimplexDomain(Location):
f" expected {self.variables}."
)
# shift point
point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
point_shift = point - self._vertices_matrix[-1]
point_shift = point_shift.tensor.reshape(-1, 1)
# compute barycentric coordinates
lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)

View File

@@ -96,6 +96,28 @@ class LabelTensor(torch.Tensor):
self._labels = labels # assign the label
@staticmethod
def vstack(label_tensors):
"""
Stack tensors vertically. For more details, see
:meth:`torch.vstack`.
:param list(LabelTensor) label_tensors: the tensors to stack. They need
to have equal labels.
:return: the stacked tensor
:rtype: LabelTensor
"""
if len(label_tensors) == 0:
return []
all_labels = [label for lt in label_tensors for label in lt.labels]
if set(all_labels) != set(label_tensors[0].labels):
raise RuntimeError('The tensors to stack have different labels')
labels = label_tensors[0].labels
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)
# TODO remove try/ except thing IMPORTANT
# make the label None of default
def clone(self, *args, **kwargs):
@@ -183,6 +205,18 @@ class LabelTensor(torch.Tensor):
return extracted_tensor
def detach(self):
detached = super().detach()
if hasattr(self, '_labels'):
detached._labels = self._labels
return detached
def requires_grad_(self, mode = True) -> Tensor:
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
def append(self, lt, mode='std'):
"""
Return a copy of the merged tensors.
@@ -232,7 +266,7 @@ class LabelTensor(torch.Tensor):
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
@@ -246,8 +280,14 @@ class LabelTensor(torch.Tensor):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
else:
selected_lt.labels = self.labels
return selected_lt
@property
def tensor(self):
return self.as_subclass(Tensor)
def __len__(self) -> int:
return super().__len__()

View File

@@ -111,8 +111,9 @@ class LpLoss(LossInterface):
# check consistency
check_consistency(p, (str,int,float))
self.p = p
check_consistency(relative, bool)
self.p = p
self.relative = relative
def forward(self, input, target):

View File

@@ -22,6 +22,8 @@ class AbstractProblem(metaclass=ABCMeta):
# varible to check if sampling is done. If no location
# element is presented in Condition this variable is set to true
self._have_sampled_points = {}
for condition_name in self.conditions:
self._have_sampled_points[condition_name] = False
# put in self.input_pts all the points that we don't need to sample
self._span_condition_points()
@@ -102,15 +104,10 @@ class AbstractProblem(metaclass=ABCMeta):
"""
for condition_name in self.conditions:
condition = self.conditions[condition_name]
if hasattr(condition, 'equation') and hasattr(condition, 'input_points'):
if hasattr(condition, 'input_points'):
samples = condition.input_points
elif hasattr(condition, 'output_points') and hasattr(condition, 'input_points'):
samples = (condition.input_points, condition.output_points)
# skip if we need to sample
elif hasattr(condition, 'location'):
self._have_sampled_points[condition_name] = False
continue
self.input_pts[condition_name] = samples
self.input_pts[condition_name] = samples
self._have_sampled_points[condition_name] = True
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
"""
@@ -204,7 +201,7 @@ class AbstractProblem(metaclass=ABCMeta):
def add_points(self, new_points):
"""
Adding points to the already sampled points
Adding points to the already sampled points.
:param dict new_points: a dictionary with key the location to add the points
and values the torch.Tensor points.

View File

@@ -115,12 +115,15 @@ class GAROM(SolverInterface):
check_consistency(lambda_k, float)
check_consistency(regularizer, bool)
# assign schedulers
self._schedulers = [scheduler_generator(self.optimizers[0],
**scheduler_generator_kwargs),
scheduler_discriminator(self.optimizers[1],
**scheduler_discriminator_kwargs)]
self._schedulers = [
scheduler_generator(
self.optimizers[0], **scheduler_generator_kwargs),
scheduler_discriminator(
self.optimizers[1],
**scheduler_discriminator_kwargs)
]
# loss and writer
self._loss = loss
@@ -157,6 +160,63 @@ class GAROM(SolverInterface):
def sample(self, x):
# sampling
return self.generator(x)
def _train_generator(self, parameters, snapshots):
"""
Private method to train the generator network.
"""
optimizer = self.optimizer_generator
generated_snapshots = self.generator(parameters)
# generator loss
r_loss = self._loss(snapshots, generated_snapshots)
d_fake = self.discriminator([generated_snapshots, parameters])
g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
# backward step
g_loss.backward()
optimizer.step()
return r_loss, g_loss
def _train_discriminator(self, parameters, snapshots):
"""
Private method to train the discriminator network.
"""
optimizer = self.optimizer_discriminator
optimizer.zero_grad()
# Generate a batch of images
generated_snapshots = self.generator(parameters)
# Discriminator pass
d_real = self.discriminator([snapshots, parameters])
d_fake = self.discriminator([generated_snapshots, parameters])
# evaluate loss
d_loss_real = self._loss(d_real, snapshots)
d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward(retain_graph=True)
optimizer.step()
return d_loss_real, d_loss_fake, d_loss
def _update_weights(self, d_loss_real, d_loss_fake):
"""
Private method to Update the weights of the generator and discriminator
networks.
"""
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
self.k += self.lambda_k * diff.item()
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
return diff
def training_step(self, batch, batch_idx):
"""PINN solver training step.
@@ -169,78 +229,40 @@ class GAROM(SolverInterface):
:rtype: LabelTensor
"""
for condition_name, samples in batch.items():
dataloader = self.trainer.train_dataloader
condition_idx = batch['condition']
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
out = batch['output']
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
condition = self.problem.conditions[condition_name]
# for data driven mode
if hasattr(condition, 'output_points'):
# get data
parameters, input_pts = samples
# get optimizers
opt_gen, opt_disc = self.optimizers
# ---------------------
# Train Discriminator
# ---------------------
opt_disc.zero_grad()
# Generate a batch of images
gen_imgs = self.generator(parameters)
# Discriminator pass
d_real = self.discriminator([input_pts, parameters])
d_fake = self.discriminator([gen_imgs.detach(), parameters])
# evaluate loss
d_loss_real = self._loss(d_real, input_pts)
d_loss_fake = self._loss(d_fake, gen_imgs.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward()
opt_disc.step()
# -----------------
# Train Generator
# -----------------
opt_gen.zero_grad()
# Generate a batch of images
gen_imgs = self.generator(parameters)
# generator loss
r_loss = self._loss(input_pts, gen_imgs)
d_fake = self.discriminator([gen_imgs, parameters])
g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss
# backward step
g_loss.backward()
opt_gen.step()
# ----------------
# Update weights
# ----------------
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
self.k += self.lambda_k * diff.item()
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
# logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
else:
if not hasattr(condition, 'output_points'):
raise NotImplementedError('GAROM works only in data-driven mode.')
# get data
snapshots = out[condition_idx == condition_id]
parameters = pts[condition_idx == condition_id]
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
parameters, snapshots)
r_loss, g_loss = self._train_generator(parameters, snapshots)
diff = self._update_weights(d_loss_real, d_loss_fake)
# logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
return
@property

View File

@@ -97,6 +97,15 @@ class PINN(SolverInterface):
"""
return self.optimizers, [self.scheduler]
def _loss_data(self, input, output):
return self.loss(self.forward(input), output)
def _loss_phys(self, samples, equation):
residual = equation.residual(samples, self.forward(samples))
return self.loss(torch.zeros_like(residual, requires_grad=True), residual)
def training_step(self, batch, batch_idx):
"""PINN solver training step.
@@ -108,25 +117,29 @@ class PINN(SolverInterface):
:rtype: LabelTensor
"""
dataloader = self.trainer.train_dataloader
condition_losses = []
condition_names = []
for condition_name, samples in batch.items():
condition_idx = batch['condition']
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
condition_names.append(condition_name)
condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
# PINN loss: equation evaluated on location or input_points
if hasattr(condition, 'equation'):
target = condition.equation.residual(samples, self.forward(samples))
loss = self.loss(torch.zeros_like(target), target)
# PINN loss: evaluate model(input_points) vs output_points
elif hasattr(condition, 'output_points'):
input_pts, output_pts = samples
loss = self.loss(self.forward(input_pts), output_pts)
if len(batch) == 2:
samples = pts[condition_idx == condition_id]
loss = self._loss_phys(pts, condition.equation)
elif len(batch) == 3:
samples = pts[condition_idx == condition_id]
ground_truth = batch['output'][condition_idx == condition_id]
loss = self._loss_data(samples, ground_truth)
else:
raise ValueError("Batch size not supported")
loss = loss.as_subclass(torch.Tensor)
loss = loss
condition_losses.append(loss * condition.data_weight)
@@ -135,8 +148,8 @@ class PINN(SolverInterface):
total_loss = sum(condition_losses)
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
for condition_loss, loss in zip(condition_names, condition_losses):
self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
# for condition_loss, loss in zip(condition_names, condition_losses):
# self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss
@property

View File

@@ -2,7 +2,7 @@
from abc import ABCMeta, abstractmethod
from ..model.network import Network
import lightning.pytorch as pl
import pytorch_lightning as pl
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch

View File

@@ -1,18 +1,19 @@
""" Solver module. """
import lightning.pytorch as pl
from pytorch_lightning import Trainer
from .utils import check_consistency
from .dataset import DummyLoader
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface
class Trainer(pl.Trainer):
class Trainer(Trainer):
def __init__(self, solver, **kwargs):
def __init__(self, solver, batch_size=None, **kwargs):
super().__init__(**kwargs)
# check inheritance consistency for solver
check_consistency(solver, SolverInterface)
self._model = solver
self.batch_size = batch_size
# create dataloader
if solver.problem.have_sampled_points is False:
@@ -22,19 +23,31 @@ class Trainer(pl.Trainer):
'discretise_domain function before train '
'in the provided locations.')
# TODO: make a better dataloader for train
self._create_or_update_loader()
# this method is used here because is resampling is needed
# during training, there is no need to define to touch the
# trainer dataloader, just call the method.
def _create_or_update_loader(self):
# get accellerator
device = self._accelerator_connector._accelerator_flag
self._loader = DummyLoader(self._model.problem.input_pts, device)
"""
This method is used here because is resampling is needed
during training, there is no need to define to touch the
trainer dataloader, just call the method.
"""
devices = self._accelerator_connector._parallel_devices
def train(self, **kwargs): # TODO add kwargs and lightining capabilities
return super().fit(self._model, self._loader, **kwargs)
if len(devices) > 1:
raise RuntimeError('Parallel training is not supported yet.')
device = devices[0]
dataset_phys = SamplePointDataset(self._model.problem, device)
dataset_data = DataPointDataset(self._model.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size,
shuffle=True)
def train(self, **kwargs):
"""
Train the solver.
"""
return super().fit(self._model, train_dataloaders=self._loader, **kwargs)
@property
def solver(self):

View File

@@ -15,7 +15,7 @@ VERSION = meta['__version__']
KEYWORDS = 'physics-informed neural-network'
REQUIRED = [
'numpy', 'matplotlib', 'torch', 'lightning'
'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning'
]
EXTRAS = {

View File

@@ -44,9 +44,9 @@ class Poisson(SpatialProblem):
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_)
# 'data': Condition(
# input_points=in_,
# output_points=out_)
}

View File

@@ -44,9 +44,9 @@ class Poisson(SpatialProblem):
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_)
# 'data': Condition(
# input_points=in_,
# output_points=out_)
}

122
tests/test_dataset.py Normal file
View File

@@ -0,0 +1,122 @@
import torch
import pytest
from pina.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from pina import LabelTensor, Condition
from pina.equation import Equation
from pina.geometry import CartesianDomain
from pina.problem import SpatialProblem
from pina.model import FeedForward
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
conditions = {
'gamma1': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 1}),
equation=FixedValue(0.0)),
'gamma2': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 0}),
equation=FixedValue(0.0)),
'gamma3': Condition(
location=CartesianDomain({'x': 1, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'gamma4': Condition(
location=CartesianDomain({'x': 0, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_),
'data2': Condition(
input_points=in2_,
output_points=out2_)
}
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
poisson = Poisson()
poisson.discretise_domain(10, 'grid', locations=boundaries)
def test_sample():
sample_dataset = SamplePointDataset(poisson, device='cpu')
assert len(sample_dataset) == 140
assert sample_dataset.pts.shape == (140, 2)
assert sample_dataset.pts.labels == ['x', 'y']
assert sample_dataset.condition_indeces.dtype == torch.int64
assert sample_dataset.condition_indeces.max() == torch.tensor(4)
assert sample_dataset.condition_indeces.min() == torch.tensor(0)
def test_data():
dataset = DataPointDataset(poisson, device='cpu')
assert len(dataset) == 61
assert dataset.input_pts.shape == (61, 2)
assert dataset.input_pts.labels == ['x', 'y']
assert dataset.output_pts.shape == (61, 1 )
assert dataset.output_pts.labels == ['u']
assert dataset.condition_indeces.dtype == torch.int64
assert dataset.condition_indeces.max() == torch.tensor(1)
assert dataset.condition_indeces.min() == torch.tensor(0)
def test_loader():
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
for batch in loader:
assert len(batch) in [2, 3]
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']
loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
assert len(list(loader2)) == 2
def test_loader2():
poisson2 = Poisson()
del poisson.conditions['data2']
del poisson2.conditions['data']
poisson2.discretise_domain(10, 'grid', locations=boundaries)
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
for batch in loader:
assert len(batch) == 2 # only phys condtions
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']
def test_loader3():
poisson2 = Poisson()
del poisson.conditions['gamma1']
del poisson.conditions['gamma2']
del poisson.conditions['gamma3']
del poisson.conditions['gamma4']
del poisson.conditions['D']
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
for batch in loader:
assert len(batch) == 2 # only phys condtions
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']

View File

@@ -95,10 +95,14 @@ def test_getitem():
def test_getitem2():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5]
assert tensor_view.labels == labels
assert torch.allclose(tensor_view, data[:5])
idx = torch.randperm(tensor.shape[0])
tensor_view = tensor[idx]
assert tensor_view.labels == labels
def test_slice():
tensor = LabelTensor(data, labels)
tensor_view = tensor[:5, :2]

View File

@@ -134,7 +134,7 @@ def test_train_cpu():
hidden_dimension=64)
)
trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu')
trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu', batch_size=20)
trainer.train()
def test_sample():

View File

@@ -22,6 +22,8 @@ def laplace_equation(input_, output_):
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
@@ -45,7 +47,10 @@ class Poisson(SpatialProblem):
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_)
output_points=out_),
'data2': Condition(
input_points=in2_,
output_points=out2_)
}
def poisson_sol(self, pts):
@@ -92,7 +97,7 @@ def test_train_cpu():
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20)
trainer.train()
def test_train_restore():
@@ -106,7 +111,7 @@ def test_train_restore():
trainer.train()
ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
t = ntrainer.train(
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
import shutil
shutil.rmtree(tmpdir)
@@ -121,7 +126,7 @@ def test_train_load():
default_root_dir=tmpdir)
trainer.train()
new_pinn = PINN.load_from_checkpoint(
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
problem = poisson_problem, model=model)
test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)