add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points
* unittests
@@ -15,4 +15,7 @@ from .trainer import Trainer
from .plotter import Plotter
from .condition import Condition
from .geometry import Location
from .geometry import CartesianDomain

from .dataset import SamplePointDataset
from .dataset import SamplePointLoader
@@ -1,6 +1,7 @@
'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
# from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch
from ..utils import check_consistency
@@ -1,6 +1,6 @@
'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch
from ..utils import check_consistency
pina/dataset.py
@@ -1,78 +1,240 @@
from torch.utils.data import Dataset, DataLoader
import functools
from torch.utils.data import Dataset
import torch
from pina import LabelTensor
class PinaDataset():
class SamplePointDataset(Dataset):
    """
    This class is used to create a dataset of sample points.
    """

    def __init__(self, pinn) -> None:
        self.pinn = pinn

    @property
    def dataloader(self):
        return self._create_dataloader()

    @property
    def dataset(self):
        return [self.SampleDataset(key, val)
                for key, val in self.input_pts.items()]

    def _create_dataloader(self):
        """Private method for creating dataloader

        :return: dataloader
        :rtype: torch.utils.data.DataLoader
    def __init__(self, problem, device) -> None:
        """
        if self.pinn.batch_size is None:
            return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
        :param dict input_pts: The input points.
        """
        super().__init__()
        pts_list = []
        self.condition_names = []

        def custom_collate(batch):
            # extracting pts labels
            _, pts = list(batch[0].items())[0]
            labels = pts.labels
            # calling default torch collate
            collate_res = default_collate(batch)
            # save collate result in dict
            res = {}
            for key, val in collate_res.items():
                val.labels = labels
                res[key] = val
    def __getitem__(self, index):
        tensor = self._tensor.select(0, index)
        return {self._location: tensor}
        for name, condition in problem.conditions.items():
            if not hasattr(condition, 'output_points'):
                pts_list.append(problem.input_pts[name])
                self.condition_names.append(name)

    def __len__(self):
        return self._len
        self.pts = LabelTensor.vstack(pts_list)

        if self.pts != []:
            self.condition_indeces = torch.cat([
                torch.tensor([i]*len(pts_list[i]))
                for i in range(len(self.condition_names))
            ], dim=0)
        else:  # if there are no sample points
            self.condition_indeces = torch.tensor([])
            self.pts = torch.tensor([])

        self.pts = self.pts.to(device)
        self.condition_indeces = self.condition_indeces.to(device)

    def __len__(self):
        return self.pts.shape[0]
class DataPointDataset(Dataset):

    def __init__(self, problem, device) -> None:
        super().__init__()
        input_list = []
        output_list = []
        self.condition_names = []

        for name, condition in problem.conditions.items():
            if hasattr(condition, 'output_points'):
                input_list.append(problem.conditions[name].input_points)
                output_list.append(problem.conditions[name].output_points)
                self.condition_names.append(name)

        self.input_pts = LabelTensor.vstack(input_list)
        self.output_pts = LabelTensor.vstack(output_list)

        if self.input_pts != []:
            self.condition_indeces = torch.cat([
                torch.tensor([i]*len(input_list[i]))
                for i in range(len(self.condition_names))
            ], dim=0)
        else:  # if there are no data points
            self.condition_indeces = torch.tensor([])
            self.input_pts = torch.tensor([])
            self.output_pts = torch.tensor([])

        self.input_pts = self.input_pts.to(device)
        self.output_pts = self.output_pts.to(device)
        self.condition_indeces = self.condition_indeces.to(device)

    def __len__(self):
        return self.input_pts.shape[0]
class SamplePointLoader:
    """
    This class is used to create a dataloader to use during the training.

# TODO: working also for datapoints
class DummyLoader:
    :var condition_names: The names of the conditions. The order is consistent
        with the condition indices in the batches.
    :vartype condition_names: list[str]
    """

    def __init__(self, data, device) -> None:
    def __init__(self, sample_dataset, data_dataset, batch_size=None, shuffle=True) -> None:
        """
        Constructor.

        # TODO: We need to make a dataset somehow
        # and the PINADataset needs to have a method
        # to send points to device
        # now we simply do it here
        # send data to device
        def convert_tensors(pts, device):
            pts = pts.to(device)
            pts.requires_grad_(True)
            pts.retain_grad()
            return pts
        :param SamplePointDataset sample_pts: The sample points dataset.
        :param int batch_size: The batch size. If ``None``, the batch size is
            set to the number of sample points. Default is ``None``.
        :param bool shuffle: If ``True``, the sample points are shuffled.
            Default is ``True``.
        """
        if not isinstance(sample_dataset, SamplePointDataset):
            raise TypeError(f'Expected SamplePointDataset, got {type(sample_dataset)}')
        if not isinstance(data_dataset, DataPointDataset):
            raise TypeError(f'Expected DataPointDataset, got {type(data_dataset)}')

        for location, pts in data.items():
            if isinstance(pts, (tuple, list)):
                pts = tuple(map(functools.partial(convert_tensors, device=device), pts))
            else:
                pts = pts.to(device)
                pts = pts.requires_grad_(True)
                pts.retain_grad()

            data[location] = pts
        self.n_data_conditions = len(data_dataset.condition_names)
        self.n_phys_conditions = len(sample_dataset.condition_names)
        data_dataset.condition_indeces += self.n_phys_conditions

        # iterator
        self.data = [data]
        self._prepare_sample_dataset(sample_dataset, batch_size, shuffle)
        self._prepare_data_dataset(data_dataset, batch_size, shuffle)

        self.condition_names = (
            sample_dataset.condition_names + data_dataset.condition_names)

        self.batch_list = []
        for i in range(len(self.batch_sample_pts)):
            self.batch_list.append(
                ('sample', i)
            )

        for i in range(len(self.batch_input_pts)):
            self.batch_list.append(
                ('data', i)
            )

        if shuffle:
            self.random_idx = torch.randperm(len(self.batch_list))
        else:
            self.random_idx = torch.arange(len(self.batch_list))
    def _prepare_data_dataset(self, dataset, batch_size, shuffle):
        """
        Prepare the dataset for data points.

        :param DataPointDataset dataset: The dataset.
        :param int batch_size: The batch size.
        :param bool shuffle: If ``True``, the sample points are shuffled.
        """
        self.sample_dataset = dataset

        if len(dataset) == 0:
            self.batch_data_conditions = []
            self.batch_input_pts = []
            self.batch_output_pts = []
            return

        if batch_size is None:
            batch_size = len(dataset)
        batch_num = len(dataset) // batch_size
        if len(dataset) % batch_size != 0:
            batch_num += 1

        output_labels = dataset.output_pts.labels
        input_labels = dataset.input_pts.labels
        self.tensor_conditions = dataset.condition_indeces

        if shuffle:
            idx = torch.randperm(dataset.input_pts.shape[0])
            self.input_pts = dataset.input_pts[idx]
            self.output_pts = dataset.output_pts[idx]
            self.tensor_conditions = dataset.condition_indeces[idx]

        self.batch_input_pts = torch.tensor_split(
            dataset.input_pts, batch_num)
        self.batch_output_pts = torch.tensor_split(
            dataset.output_pts, batch_num)

        for i in range(len(self.batch_input_pts)):
            self.batch_input_pts[i].labels = input_labels
            self.batch_output_pts[i].labels = output_labels

        self.batch_data_conditions = torch.tensor_split(
            self.tensor_conditions, batch_num)
    def _prepare_sample_dataset(self, dataset, batch_size, shuffle):
        """
        Prepare the dataset for sample points.

        :param SamplePointDataset dataset: The dataset.
        :param int batch_size: The batch size.
        :param bool shuffle: If ``True``, the sample points are shuffled.
        """

        self.sample_dataset = dataset
        if len(dataset) == 0:
            self.batch_sample_conditions = []
            self.batch_sample_pts = []
            return

        if batch_size is None:
            batch_size = len(dataset)

        batch_num = len(dataset) // batch_size
        if len(dataset) % batch_size != 0:
            batch_num += 1

        self.tensor_pts = dataset.pts
        self.tensor_conditions = dataset.condition_indeces

        # if shuffle:
        #     idx = torch.randperm(self.tensor_pts.shape[0])
        #     self.tensor_pts = self.tensor_pts[idx]
        #     self.tensor_conditions = self.tensor_conditions[idx]

        self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num)
        for i in range(len(self.batch_sample_pts)):
            self.batch_sample_pts[i].labels = dataset.pts.labels

        self.batch_sample_conditions = torch.tensor_split(
            self.tensor_conditions, batch_num)
    def __iter__(self):
        return iter(self.data)
        """
        Return an iterator over the points. Any element of the iterator is a
        dictionary with the following keys:
            - ``pts``: The input sample points. It is a LabelTensor with the
              shape ``(batch_size, input_dimension)``.
            - ``output``: The output sample points. This key is present only
              if data conditions are present. It is a LabelTensor with the
              shape ``(batch_size, output_dimension)``.
            - ``condition``: The integer condition indices. It is a tensor
              with the shape ``(batch_size, )`` of type ``torch.int64`` and
              indicates for any ``pts`` the corresponding problem condition.

        :return: An iterator over the points.
        :rtype: iter
        """
        # for i in self.random_idx:
        for i in range(len(self.batch_list)):
            type_, idx_ = self.batch_list[i]

            if type_ == 'sample':
                d = {
                    'pts': self.batch_sample_pts[idx_].requires_grad_(True),
                    'condition': self.batch_sample_conditions[idx_],
                }
            else:
                d = {
                    'pts': self.batch_input_pts[idx_].requires_grad_(True),
                    'output': self.batch_output_pts[idx_],
                    'condition': self.batch_data_conditions[idx_],
                }
            yield d
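The two datasets and the loader above are meant to be composed as follows. A minimal sketch, assuming `problem` is an already discretised PINA problem (e.g. after `problem.discretise_domain(...)`); the `pina.dataset` import path and the toy batch size are illustrative, not prescribed by the diff.

    import torch
    from pina.dataset import SamplePointDataset, DataPointDataset, SamplePointLoader

    device = torch.device('cpu')
    dataset_phys = SamplePointDataset(problem, device)   # physics (collocation) points
    dataset_data = DataPointDataset(problem, device)     # supervised input/output pairs

    loader = SamplePointLoader(dataset_phys, dataset_data, batch_size=32, shuffle=True)

    for batch in loader:
        pts = batch['pts']          # LabelTensor, shape (batch_size, input_dimension)
        cond = batch['condition']   # int64 tensor mapping each row to a condition name
        if 'output' in batch:       # only present for data-driven conditions
            out = batch['output']   # LabelTensor, shape (batch_size, output_dimension)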
@@ -55,13 +55,15 @@ class SimplexDomain(Location):
                raise ValueError("An n-dimensional simplex is composed by n + 1 tensors of dimension n.")

        # creating vertices matrix
        self._vertices_matrix = torch.cat(simplex_matrix)
        self._vertices_matrix.labels = matrix_labels
        self._vertices_matrix = LabelTensor.vstack(simplex_matrix)

        # creating basis vectors for simplex
        self._vectors_shifted = (
            (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
        )
        # self._vectors_shifted = (
        #     (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
        # ) ### TODO: Remove after checking

        vert = self._vertices_matrix
        self._vectors_shifted = (vert[:-1] - vert[-1]).T

        # build cartesian_bound
        self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
@@ -114,8 +116,8 @@ class SimplexDomain(Location):
                f" expected {self.variables}."
            )

        # shift point
        point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
        point_shift = point - self._vertices_matrix[-1]
        point_shift = point_shift.tensor.reshape(-1, 1)

        # compute barycentric coordinates
        lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)
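For reference, the barycentric test that this hunk implements can be written as a standalone check: solve the shifted linear system for the first n coordinates, recover the last one as one minus their sum, and require every coordinate to lie in [0, 1]. The sketch below uses plain torch tensors and hypothetical names (`inside_simplex`, `vertices`); it is not the PINA API.

    import torch

    def inside_simplex(vertices, point, tol=1e-6):
        # vertices: (n + 1, n) tensor, one row per simplex vertex
        # point: (n,) tensor to test
        vectors_shifted = (vertices[:-1] - vertices[-1]).T          # (n, n)
        point_shift = (point - vertices[-1]).reshape(-1, 1)         # (n, 1)
        lambdas = torch.linalg.solve(vectors_shifted, point_shift)  # first n coordinates
        last = 1.0 - lambdas.sum()                                  # remaining coordinate
        coords = torch.cat([lambdas.flatten(), last.reshape(1)])
        return bool(((coords >= -tol) & (coords <= 1 + tol)).all())

    triangle = torch.tensor([[0., 0.], [1., 0.], [0., 1.]])
    print(inside_simplex(triangle, torch.tensor([0.2, 0.2])))  # True
    print(inside_simplex(triangle, torch.tensor([0.8, 0.8])))  # False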
@@ -96,6 +96,28 @@ class LabelTensor(torch.Tensor):

        self._labels = labels  # assign the label

    @staticmethod
    def vstack(label_tensors):
        """
        Stack tensors vertically. For more details, see
        :meth:`torch.vstack`.

        :param list(LabelTensor) label_tensors: the tensors to stack. They need
            to have equal labels.
        :return: the stacked tensor
        :rtype: LabelTensor
        """
        if len(label_tensors) == 0:
            return []

        all_labels = [label for lt in label_tensors for label in lt.labels]
        if set(all_labels) != set(label_tensors[0].labels):
            raise RuntimeError('The tensors to stack have different labels')

        labels = label_tensors[0].labels
        tensors = [lt.extract(labels) for lt in label_tensors]
        return LabelTensor(torch.vstack(tensors), labels)

    # TODO remove try/ except thing IMPORTANT
    # make the label None of default
    def clone(self, *args, **kwargs):
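A minimal usage sketch for the new `vstack` (illustrative shapes and labels):

    import torch
    from pina import LabelTensor

    a = LabelTensor(torch.rand(3, 2), ['x', 'y'])
    b = LabelTensor(torch.rand(5, 2), ['x', 'y'])
    stacked = LabelTensor.vstack([a, b])
    print(stacked.shape)   # torch.Size([8, 2])
    print(stacked.labels)  # ['x', 'y']

    # Tensors whose labels differ raise a RuntimeError:
    c = LabelTensor(torch.rand(2, 2), ['x', 'z'])
    # LabelTensor.vstack([a, c])  # -> RuntimeError: The tensors to stack have different labels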
@@ -183,6 +205,18 @@ class LabelTensor(torch.Tensor):

        return extracted_tensor

    def detach(self):
        detached = super().detach()
        if hasattr(self, '_labels'):
            detached._labels = self._labels
        return detached

    def requires_grad_(self, mode = True) -> Tensor:
        lt = super().requires_grad_(mode)
        lt.labels = self.labels
        return lt

    def append(self, lt, mode='std'):
        """
        Return a copy of the merged tensors.
@@ -232,7 +266,7 @@ class LabelTensor(torch.Tensor):
            len_index = len(index)
        except TypeError:
            len_index = 1

        if isinstance(index, int) or len_index == 1:
            if selected_lt.ndim == 1:
                selected_lt = selected_lt.reshape(1, -1)
@@ -246,8 +280,14 @@ class LabelTensor(torch.Tensor):
                selected_lt.labels = [self.labels[i] for i in index[1]]
            else:
                selected_lt.labels = self.labels[index[1]]
        else:
            selected_lt.labels = self.labels

        return selected_lt

    @property
    def tensor(self):
        return self.as_subclass(Tensor)

    def __len__(self) -> int:
        return super().__len__()
@@ -111,8 +111,9 @@ class LpLoss(LossInterface):

        # check consistency
        check_consistency(p, (str,int,float))
        self.p = p
        check_consistency(relative, bool)

        self.p = p
        self.relative = relative

    def forward(self, input, target):
@@ -22,6 +22,8 @@ class AbstractProblem(metaclass=ABCMeta):
        # variable to check if sampling is done. If no location
        # element is presented in Condition this variable is set to true
        self._have_sampled_points = {}
        for condition_name in self.conditions:
            self._have_sampled_points[condition_name] = False

        # put in self.input_pts all the points that we don't need to sample
        self._span_condition_points()
@@ -102,15 +104,10 @@ class AbstractProblem(metaclass=ABCMeta):
        """
        for condition_name in self.conditions:
            condition = self.conditions[condition_name]
            if hasattr(condition, 'equation') and hasattr(condition, 'input_points'):
            if hasattr(condition, 'input_points'):
                samples = condition.input_points
            elif hasattr(condition, 'output_points') and hasattr(condition, 'input_points'):
                samples = (condition.input_points, condition.output_points)
            # skip if we need to sample
            elif hasattr(condition, 'location'):
                self._have_sampled_points[condition_name] = False
                continue
            self.input_pts[condition_name] = samples
            self._have_sampled_points[condition_name] = True

    def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
        """
@@ -204,7 +201,7 @@ class AbstractProblem(metaclass=ABCMeta):

    def add_points(self, new_points):
        """
        Adding points to the already sampled points
        Adding points to the already sampled points.

        :param dict new_points: a dictionary with key the location to add the points
            and values the torch.Tensor points.
@@ -115,12 +115,15 @@ class GAROM(SolverInterface):
        check_consistency(lambda_k, float)
        check_consistency(regularizer, bool)

        # assign schedulers
        self._schedulers = [scheduler_generator(self.optimizers[0],
                                                **scheduler_generator_kwargs),
                            scheduler_discriminator(self.optimizers[1],
                                                    **scheduler_discriminator_kwargs)]
        self._schedulers = [
            scheduler_generator(
                self.optimizers[0], **scheduler_generator_kwargs),
            scheduler_discriminator(
                self.optimizers[1],
                **scheduler_discriminator_kwargs)
        ]

        # loss and writer
        self._loss = loss
@@ -157,6 +160,63 @@ class GAROM(SolverInterface):
    def sample(self, x):
        # sampling
        return self.generator(x)

    def _train_generator(self, parameters, snapshots):
        """
        Private method to train the generator network.
        """
        optimizer = self.optimizer_generator

        generated_snapshots = self.generator(parameters)

        # generator loss
        r_loss = self._loss(snapshots, generated_snapshots)
        d_fake = self.discriminator([generated_snapshots, parameters])
        g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss

        # backward step
        g_loss.backward()
        optimizer.step()

        return r_loss, g_loss

    def _train_discriminator(self, parameters, snapshots):
        """
        Private method to train the discriminator network.
        """
        optimizer = self.optimizer_discriminator
        optimizer.zero_grad()

        # Generate a batch of images
        generated_snapshots = self.generator(parameters)

        # Discriminator pass
        d_real = self.discriminator([snapshots, parameters])
        d_fake = self.discriminator([generated_snapshots, parameters])

        # evaluate loss
        d_loss_real = self._loss(d_real, snapshots)
        d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
        d_loss = d_loss_real - self.k * d_loss_fake

        # backward step
        d_loss.backward(retain_graph=True)
        optimizer.step()

        return d_loss_real, d_loss_fake, d_loss

    def _update_weights(self, d_loss_real, d_loss_fake):
        """
        Private method to update the weights of the generator and discriminator
        networks.
        """

        diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)

        # Update weight term for fake samples
        self.k += self.lambda_k * diff.item()
        self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]
        return diff

    def training_step(self, batch, batch_idx):
        """PINN solver training step.
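The new `_update_weights` helper keeps the balance term `k` inside [0, 1]. A purely numeric illustration of that update, with made-up values for `gamma`, `lambda_k` and the two losses (these are not the solver's defaults):

    # Standalone numeric sketch of _update_weights (illustrative values only).
    gamma, lambda_k, k = 0.3, 0.001, 0.0
    d_loss_real, d_loss_fake = 0.5, 0.2

    diff = gamma * d_loss_real - d_loss_fake   # -0.05
    k = k + lambda_k * diff                    # -5e-05
    k = min(max(k, 0), 1)                      # clamped back to 0.0
    print(diff, k)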
@@ -169,78 +229,40 @@ class GAROM(SolverInterface):
        :rtype: LabelTensor
        """

        for condition_name, samples in batch.items():
        dataloader = self.trainer.train_dataloader
        condition_idx = batch['condition']

        for condition_id in range(condition_idx.min(), condition_idx.max()+1):

            condition_name = dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch['pts']
            out = batch['output']

            if condition_name not in self.problem.conditions:
                raise RuntimeError('Something wrong happened.')

            condition = self.problem.conditions[condition_name]

            # for data driven mode
            if hasattr(condition, 'output_points'):

                # get data
                parameters, input_pts = samples

                # get optimizers
                opt_gen, opt_disc = self.optimizers

                # ---------------------
                #  Train Discriminator
                # ---------------------
                opt_disc.zero_grad()

                # Generate a batch of images
                gen_imgs = self.generator(parameters)

                # Discriminator pass
                d_real = self.discriminator([input_pts, parameters])
                d_fake = self.discriminator([gen_imgs.detach(), parameters])

                # evaluate loss
                d_loss_real = self._loss(d_real, input_pts)
                d_loss_fake = self._loss(d_fake, gen_imgs.detach())
                d_loss = d_loss_real - self.k * d_loss_fake

                # backward step
                d_loss.backward()
                opt_disc.step()

                # -----------------
                #  Train Generator
                # -----------------
                opt_gen.zero_grad()

                # Generate a batch of images
                gen_imgs = self.generator(parameters)

                # generator loss
                r_loss = self._loss(input_pts, gen_imgs)
                d_fake = self.discriminator([gen_imgs, parameters])
                g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss

                # backward step
                g_loss.backward()
                opt_gen.step()

                # ----------------
                # Update weights
                # ----------------
                diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)

                # Update weight term for fake samples
                self.k += self.lambda_k * diff.item()
                self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]

                # logging
                self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
                self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
                self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
                self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)

            else:
            if not hasattr(condition, 'output_points'):
                raise NotImplementedError('GAROM works only in data-driven mode.')

            # get data
            snapshots = out[condition_idx == condition_id]
            parameters = pts[condition_idx == condition_id]

            d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
                parameters, snapshots)

            r_loss, g_loss = self._train_generator(parameters, snapshots)

            diff = self._update_weights(d_loss_real, d_loss_fake)

            # logging
            self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
            self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
            self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
            self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)

        return

    @property
@@ -97,6 +97,15 @@ class PINN(SolverInterface):
        """
        return self.optimizers, [self.scheduler]

    def _loss_data(self, input, output):
        return self.loss(self.forward(input), output)

    def _loss_phys(self, samples, equation):
        residual = equation.residual(samples, self.forward(samples))
        return self.loss(torch.zeros_like(residual, requires_grad=True), residual)

    def training_step(self, batch, batch_idx):
        """PINN solver training step.
@@ -108,25 +117,29 @@
        :rtype: LabelTensor
        """

        dataloader = self.trainer.train_dataloader
        condition_losses = []
        condition_names = []

        for condition_name, samples in batch.items():
        condition_idx = batch['condition']

            if condition_name not in self.problem.conditions:
                raise RuntimeError('Something wrong happened.')
        for condition_id in range(condition_idx.min(), condition_idx.max()+1):

            condition_names.append(condition_name)
            condition_name = dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch['pts']

            # PINN loss: equation evaluated on location or input_points
            if hasattr(condition, 'equation'):
                target = condition.equation.residual(samples, self.forward(samples))
                loss = self.loss(torch.zeros_like(target), target)
            # PINN loss: evaluate model(input_points) vs output_points
            elif hasattr(condition, 'output_points'):
                input_pts, output_pts = samples
                loss = self.loss(self.forward(input_pts), output_pts)
            if len(batch) == 2:
                samples = pts[condition_idx == condition_id]
                loss = self._loss_phys(pts, condition.equation)
            elif len(batch) == 3:
                samples = pts[condition_idx == condition_id]
                ground_truth = batch['output'][condition_idx == condition_id]
                loss = self._loss_data(samples, ground_truth)
            else:
                raise ValueError("Batch size not supported")

            loss = loss.as_subclass(torch.Tensor)
            loss = loss

            condition_losses.append(loss * condition.data_weight)
@@ -135,8 +148,8 @@
        total_loss = sum(condition_losses)

        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
        for condition_loss, loss in zip(condition_names, condition_losses):
            self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
        # for condition_loss, loss in zip(condition_names, condition_losses):
        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
        return total_loss

    @property
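The rewritten `training_step` infers the loss type from the batch layout produced by `SamplePointLoader`: a two-key batch triggers the physics residual loss, a three-key batch the data loss. A sketch of the two layouts with toy tensors (in practice the loader builds these dictionaries):

    import torch
    from pina import LabelTensor

    physics_batch = {
        'pts': LabelTensor(torch.rand(16, 2), ['x', 'y']).requires_grad_(True),
        'condition': torch.zeros(16, dtype=torch.int64),
    }   # len(batch) == 2 -> _loss_phys(pts, condition.equation)

    data_batch = {
        'pts': LabelTensor(torch.rand(16, 2), ['x', 'y']).requires_grad_(True),
        'output': LabelTensor(torch.rand(16, 1), ['u']),
        'condition': torch.ones(16, dtype=torch.int64),
    }   # len(batch) == 3 -> _loss_data(samples, ground_truth)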
@@ -2,7 +2,7 @@

from abc import ABCMeta, abstractmethod
from ..model.network import Network
import lightning.pytorch as pl
import pytorch_lightning as pl
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch
@@ -1,18 +1,19 @@
""" Solver module. """

import lightning.pytorch as pl
from pytorch_lightning import Trainer
from .utils import check_consistency
from .dataset import DummyLoader
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface


class Trainer(pl.Trainer):
class Trainer(Trainer):

    def __init__(self, solver, **kwargs):
    def __init__(self, solver, batch_size=None, **kwargs):
        super().__init__(**kwargs)

        # check inheritance consistency for solver
        check_consistency(solver, SolverInterface)
        self._model = solver
        self.batch_size = batch_size

        # create dataloader
        if solver.problem.have_sampled_points is False:
@@ -22,19 +23,31 @@ class Trainer(pl.Trainer):
                'discretise_domain function before train '
                'in the provided locations.')

        # TODO: make a better dataloader for train
        self._create_or_update_loader()

    # this method is used here because, if resampling is needed
    # during training, there is no need to touch the
    # trainer dataloader, just call the method.
    def _create_or_update_loader(self):
        # get accelerator
        device = self._accelerator_connector._accelerator_flag
        self._loader = DummyLoader(self._model.problem.input_pts, device)
        """
        This method is used here because, if resampling is needed
        during training, there is no need to touch the
        trainer dataloader, just call the method.
        """
        devices = self._accelerator_connector._parallel_devices

    def train(self, **kwargs):  # TODO add kwargs and lightning capabilities
        return super().fit(self._model, self._loader, **kwargs)
        if len(devices) > 1:
            raise RuntimeError('Parallel training is not supported yet.')

        device = devices[0]
        dataset_phys = SamplePointDataset(self._model.problem, device)
        dataset_data = DataPointDataset(self._model.problem, device)
        self._loader = SamplePointLoader(
            dataset_phys, dataset_data, batch_size=self.batch_size,
            shuffle=True)

    def train(self, **kwargs):
        """
        Train the solver.
        """
        return super().fit(self._model, train_dataloaders=self._loader, **kwargs)

    @property
    def solver(self):
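With these changes the training entry point looks roughly as follows; a sketch assuming `solver` is any `SolverInterface` (e.g. a `PINN`) built on an already discretised problem, with `max_epochs` passed through to the underlying Lightning trainer:

    from pina import Trainer

    trainer = Trainer(solver, batch_size=32, max_epochs=1000)  # batch_size=None -> one full batch
    trainer.train()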