add dataset and dataloader for sample points (#195)

* add dataset and dataloader for sample points
* unittests
This commit is contained in:
Nicola Demo
2023-11-07 11:34:44 +01:00
parent cd5bc9a558
commit d654259428
19 changed files with 581 additions and 196 deletions

View File

@@ -115,12 +115,15 @@ class GAROM(SolverInterface):
check_consistency(lambda_k, float)
check_consistency(regularizer, bool)
# assign schedulers
self._schedulers = [scheduler_generator(self.optimizers[0],
**scheduler_generator_kwargs),
scheduler_discriminator(self.optimizers[1],
**scheduler_discriminator_kwargs)]
self._schedulers = [
scheduler_generator(
self.optimizers[0], **scheduler_generator_kwargs),
scheduler_discriminator(
self.optimizers[1],
**scheduler_discriminator_kwargs)
]
# loss and writer
self._loss = loss
@@ -157,6 +160,63 @@ class GAROM(SolverInterface):
def sample(self, x):
# sampling
return self.generator(x)
def _train_generator(self, parameters, snapshots):
"""
Private method to train the generator network.
"""
optimizer = self.optimizer_generator
generated_snapshots = self.generator(parameters)
# generator loss
r_loss = self._loss(snapshots, generated_snapshots)
d_fake = self.discriminator([generated_snapshots, parameters])
g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
# backward step
g_loss.backward()
optimizer.step()
return r_loss, g_loss
def _train_discriminator(self, parameters, snapshots):
"""
Private method to train the discriminator network.
"""
optimizer = self.optimizer_discriminator
optimizer.zero_grad()
# Generate a batch of images
generated_snapshots = self.generator(parameters)
# Discriminator pass
d_real = self.discriminator([snapshots, parameters])
d_fake = self.discriminator([generated_snapshots, parameters])
# evaluate loss
d_loss_real = self._loss(d_real, snapshots)
d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward(retain_graph=True)
optimizer.step()
return d_loss_real, d_loss_fake, d_loss
def _update_weights(self, d_loss_real, d_loss_fake):
"""
Private method to Update the weights of the generator and discriminator
networks.
"""
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
self.k += self.lambda_k * diff.item()
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
return diff
def training_step(self, batch, batch_idx):
"""PINN solver training step.
@@ -169,78 +229,40 @@ class GAROM(SolverInterface):
:rtype: LabelTensor
"""
for condition_name, samples in batch.items():
dataloader = self.trainer.train_dataloader
condition_idx = batch['condition']
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
out = batch['output']
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
condition = self.problem.conditions[condition_name]
# for data driven mode
if hasattr(condition, 'output_points'):
# get data
parameters, input_pts = samples
# get optimizers
opt_gen, opt_disc = self.optimizers
# ---------------------
# Train Discriminator
# ---------------------
opt_disc.zero_grad()
# Generate a batch of images
gen_imgs = self.generator(parameters)
# Discriminator pass
d_real = self.discriminator([input_pts, parameters])
d_fake = self.discriminator([gen_imgs.detach(), parameters])
# evaluate loss
d_loss_real = self._loss(d_real, input_pts)
d_loss_fake = self._loss(d_fake, gen_imgs.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward()
opt_disc.step()
# -----------------
# Train Generator
# -----------------
opt_gen.zero_grad()
# Generate a batch of images
gen_imgs = self.generator(parameters)
# generator loss
r_loss = self._loss(input_pts, gen_imgs)
d_fake = self.discriminator([gen_imgs, parameters])
g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss
# backward step
g_loss.backward()
opt_gen.step()
# ----------------
# Update weights
# ----------------
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
self.k += self.lambda_k * diff.item()
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
# logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
else:
if not hasattr(condition, 'output_points'):
raise NotImplementedError('GAROM works only in data-driven mode.')
# get data
snapshots = out[condition_idx == condition_id]
parameters = pts[condition_idx == condition_id]
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
parameters, snapshots)
r_loss, g_loss = self._train_generator(parameters, snapshots)
diff = self._update_weights(d_loss_real, d_loss_fake)
# logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
return
@property

View File

@@ -97,6 +97,15 @@ class PINN(SolverInterface):
"""
return self.optimizers, [self.scheduler]
def _loss_data(self, input, output):
return self.loss(self.forward(input), output)
def _loss_phys(self, samples, equation):
residual = equation.residual(samples, self.forward(samples))
return self.loss(torch.zeros_like(residual, requires_grad=True), residual)
def training_step(self, batch, batch_idx):
"""PINN solver training step.
@@ -108,25 +117,29 @@ class PINN(SolverInterface):
:rtype: LabelTensor
"""
dataloader = self.trainer.train_dataloader
condition_losses = []
condition_names = []
for condition_name, samples in batch.items():
condition_idx = batch['condition']
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
condition_names.append(condition_name)
condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
# PINN loss: equation evaluated on location or input_points
if hasattr(condition, 'equation'):
target = condition.equation.residual(samples, self.forward(samples))
loss = self.loss(torch.zeros_like(target), target)
# PINN loss: evaluate model(input_points) vs output_points
elif hasattr(condition, 'output_points'):
input_pts, output_pts = samples
loss = self.loss(self.forward(input_pts), output_pts)
if len(batch) == 2:
samples = pts[condition_idx == condition_id]
loss = self._loss_phys(pts, condition.equation)
elif len(batch) == 3:
samples = pts[condition_idx == condition_id]
ground_truth = batch['output'][condition_idx == condition_id]
loss = self._loss_data(samples, ground_truth)
else:
raise ValueError("Batch size not supported")
loss = loss.as_subclass(torch.Tensor)
loss = loss
condition_losses.append(loss * condition.data_weight)
@@ -135,8 +148,8 @@ class PINN(SolverInterface):
total_loss = sum(condition_losses)
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
for condition_loss, loss in zip(condition_names, condition_losses):
self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
# for condition_loss, loss in zip(condition_names, condition_losses):
# self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss
@property

View File

@@ -2,7 +2,7 @@
from abc import ABCMeta, abstractmethod
from ..model.network import Network
import lightning.pytorch as pl
import pytorch_lightning as pl
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch