add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points
* unittests
@@ -115,12 +115,15 @@ class GAROM(SolverInterface):
        check_consistency(lambda_k, float)
        check_consistency(regularizer, bool)

        # assign schedulers
        self._schedulers = [scheduler_generator(self.optimizers[0],
                                                **scheduler_generator_kwargs),
                            scheduler_discriminator(self.optimizers[1],
                                                    **scheduler_discriminator_kwargs)]
        self._schedulers = [
            scheduler_generator(
                self.optimizers[0], **scheduler_generator_kwargs),
            scheduler_discriminator(
                self.optimizers[1],
                **scheduler_discriminator_kwargs)
        ]

        # loss and writer
        self._loss = loss
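For readers following along: the ordering matters here, since later code unpacks `self.optimizers` as `(generator, discriminator)`. A minimal sketch of how such a two-optimizer / two-scheduler pairing is typically assembled; the Adam and ConstantLR choices and the toy networks below are placeholders for illustration, not what GAROM prescribes:

    import torch

    generator = torch.nn.Linear(4, 8)          # stand-ins for the GAROM networks
    discriminator = torch.nn.Linear(8 + 4, 1)

    # position 0 -> generator, position 1 -> discriminator (the convention used above)
    opt_gen = torch.optim.Adam(generator.parameters(), lr=1e-3)
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
    optimizers = [opt_gen, opt_disc]
    schedulers = [torch.optim.lr_scheduler.ConstantLR(opt_gen),
                  torch.optim.lr_scheduler.ConstantLR(opt_disc)]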
@@ -157,6 +160,63 @@ class GAROM(SolverInterface):
    def sample(self, x):
        # sampling
        return self.generator(x)

    def _train_generator(self, parameters, snapshots):
        """
        Private method to train the generator network.
        """
        optimizer = self.optimizer_generator

        generated_snapshots = self.generator(parameters)

        # generator loss
        r_loss = self._loss(snapshots, generated_snapshots)
        d_fake = self.discriminator([generated_snapshots, parameters])
        g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss

        # backward step
        g_loss.backward()
        optimizer.step()

        return r_loss, g_loss
    def _train_discriminator(self, parameters, snapshots):
        """
        Private method to train the discriminator network.
        """
        optimizer = self.optimizer_discriminator
        optimizer.zero_grad()

        # Generate a batch of images
        generated_snapshots = self.generator(parameters)

        # Discriminator pass
        d_real = self.discriminator([snapshots, parameters])
        d_fake = self.discriminator([generated_snapshots, parameters])

        # evaluate loss
        d_loss_real = self._loss(d_real, snapshots)
        d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
        d_loss = d_loss_real - self.k * d_loss_fake

        # backward step
        d_loss.backward(retain_graph=True)
        optimizer.step()

        return d_loss_real, d_loss_fake, d_loss
    def _update_weights(self, d_loss_real, d_loss_fake):
        """
        Private method to update the weights of the generator and discriminator
        networks.
        """
        diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)

        # Update weight term for fake samples
        self.k += self.lambda_k * diff.item()
        self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]
        return diff
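Taken together, `_train_discriminator` and `_update_weights` implement an adaptive balance between the two discriminator terms: the discriminator minimises d_loss = d_loss_real - k * d_loss_fake, and after every step the balance term is nudged toward the point where gamma * d_loss_real matches d_loss_fake,

    diff = mean(gamma * d_loss_real - d_loss_fake)
    k   <- clip(k + lambda_k * diff, 0, 1)

which mirrors a boundary-equilibrium (BEGAN-style) control scheme. The `stability_metric` logged in `training_step` below, d_loss_real + |diff|, is the corresponding convergence measure.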
    def training_step(self, batch, batch_idx):
        """GAROM solver training step.
@@ -169,78 +229,40 @@ class GAROM(SolverInterface):
        :rtype: LabelTensor
        """

        for condition_name, samples in batch.items():
        dataloader = self.trainer.train_dataloader
        condition_idx = batch['condition']

        for condition_id in range(condition_idx.min(), condition_idx.max()+1):

            condition_name = dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch['pts']
            out = batch['output']

            if condition_name not in self.problem.conditions:
                raise RuntimeError('Something wrong happened.')

            condition = self.problem.conditions[condition_name]

            # for data driven mode
            if hasattr(condition, 'output_points'):

                # get data
                parameters, input_pts = samples

                # get optimizers
                opt_gen, opt_disc = self.optimizers

                # ---------------------
                # Train Discriminator
                # ---------------------
                opt_disc.zero_grad()

                # Generate a batch of images
                gen_imgs = self.generator(parameters)

                # Discriminator pass
                d_real = self.discriminator([input_pts, parameters])
                d_fake = self.discriminator([gen_imgs.detach(), parameters])

                # evaluate loss
                d_loss_real = self._loss(d_real, input_pts)
                d_loss_fake = self._loss(d_fake, gen_imgs.detach())
                d_loss = d_loss_real - self.k * d_loss_fake

                # backward step
                d_loss.backward()
                opt_disc.step()

                # -----------------
                # Train Generator
                # -----------------
                opt_gen.zero_grad()

                # Generate a batch of images
                gen_imgs = self.generator(parameters)

                # generator loss
                r_loss = self._loss(input_pts, gen_imgs)
                d_fake = self.discriminator([gen_imgs, parameters])
                g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss

                # backward step
                g_loss.backward()
                opt_gen.step()

                # ----------------
                # Update weights
                # ----------------
                diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)

                # Update weight term for fake samples
                self.k += self.lambda_k * diff.item()
                self.k = min(max(self.k, 0), 1)  # Constraint to interval [0, 1]

                # logging
                self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
                self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
                self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
                self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)

            else:
            if not hasattr(condition, 'output_points'):
                raise NotImplementedError('GAROM works only in data-driven mode.')

            # get data
            snapshots = out[condition_idx == condition_id]
            parameters = pts[condition_idx == condition_id]

            d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
                parameters, snapshots)

            r_loss, g_loss = self._train_generator(parameters, snapshots)

            diff = self._update_weights(d_loss_real, d_loss_fake)

            # logging
            self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
            self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
            self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
            self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)

        return

    @property
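The rewritten training step leans on the new sample-points dataset and dataloader from this PR: instead of iterating `batch.items()`, it expects a single dictionary batch carrying the sampled inputs, the matching outputs and a per-sample condition index, and it filters by condition id. A rough, self-contained sketch of that assumed layout (the keys 'pts', 'output' and 'condition' come from the diff; the shapes and values are illustrative only):

    import torch

    # illustrative batch, shaped the way the refactored training_step appears to expect it
    batch = {
        'pts': torch.rand(16, 2),                          # sampled parameters / input points
        'output': torch.rand(16, 100),                     # corresponding snapshots
        'condition': torch.zeros(16, dtype=torch.int64),   # condition id for each sample
    }

    condition_idx = batch['condition']
    for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
        parameters = batch['pts'][condition_idx == condition_id]
        snapshots = batch['output'][condition_idx == condition_id]
        # ... train the discriminator and generator on this condition's samples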
@@ -97,6 +97,15 @@ class PINN(SolverInterface):
        """
        return self.optimizers, [self.scheduler]

    def _loss_data(self, input, output):
        return self.loss(self.forward(input), output)

    def _loss_phys(self, samples, equation):
        residual = equation.residual(samples, self.forward(samples))
        return self.loss(torch.zeros_like(residual, requires_grad=True), residual)

    def training_step(self, batch, batch_idx):
        """PINN solver training step.
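The two helpers added in this hunk separate the supervised loss from the physics loss: `_loss_data` compares the network output with the target values, while `_loss_phys` evaluates the equation residual at the sample points and drives it toward zero. A small stand-alone sketch of the residual idea, with a toy residual function in place of PINA's `Equation` object (the specific equation below is an assumption for illustration only):

    import torch

    def toy_residual(x, u):
        # residual of du/dx - u = 0, evaluated at the sample points
        du_dx = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
        return du_dx - u

    x = torch.rand(8, 1, requires_grad=True)   # sample points
    u = x ** 2                                 # stand-in for self.forward(x)
    residual = toy_residual(x, u)
    loss = torch.nn.MSELoss()(torch.zeros_like(residual), residual)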
@@ -108,25 +117,29 @@ class PINN(SolverInterface):
        :rtype: LabelTensor
        """

        dataloader = self.trainer.train_dataloader
        condition_losses = []
        condition_names = []

        for condition_name, samples in batch.items():
        condition_idx = batch['condition']

            if condition_name not in self.problem.conditions:
                raise RuntimeError('Something wrong happened.')
        for condition_id in range(condition_idx.min(), condition_idx.max()+1):

            condition_names.append(condition_name)
            condition_name = dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch['pts']

            # PINN loss: equation evaluated on location or input_points
            if hasattr(condition, 'equation'):
                target = condition.equation.residual(samples, self.forward(samples))
                loss = self.loss(torch.zeros_like(target), target)
            # PINN loss: evaluate model(input_points) vs output_points
            elif hasattr(condition, 'output_points'):
                input_pts, output_pts = samples
                loss = self.loss(self.forward(input_pts), output_pts)
            if len(batch) == 2:
                samples = pts[condition_idx == condition_id]
                loss = self._loss_phys(pts, condition.equation)
            elif len(batch) == 3:
                samples = pts[condition_idx == condition_id]
                ground_truth = batch['output'][condition_idx == condition_id]
                loss = self._loss_data(samples, ground_truth)
            else:
                raise ValueError("Batch size not supported")

            loss = loss.as_subclass(torch.Tensor)
            loss = loss

            condition_losses.append(loss * condition.data_weight)

@@ -135,8 +148,8 @@ class PINN(SolverInterface):
        total_loss = sum(condition_losses)

        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
        for condition_loss, loss in zip(condition_names, condition_losses):
            self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
        # for condition_loss, loss in zip(condition_names, condition_losses):
        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
        return total_loss

    @property
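Two details of the rewritten loop are easy to miss: the branch on `len(batch)` distinguishes a physics-only batch (keys 'pts' and 'condition') from a supervised one (which additionally carries 'output'), and each condition's loss is scaled by its `data_weight` before the sum. A tiny sketch of the final aggregation, with made-up numbers, mirroring the code above:

    import torch

    # each entry is loss * data_weight for one condition, as appended in the loop above
    condition_losses = [1.0 * torch.tensor(0.12), 10.0 * torch.tensor(0.05)]

    total_loss = sum(condition_losses)                       # value returned by training_step
    mean_loss = float(total_loss / len(condition_losses))    # value logged as 'mean_loss'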
@@ -2,7 +2,7 @@

from abc import ABCMeta, abstractmethod
from ..model.network import Network
import lightning.pytorch as pl
import pytorch_lightning as pl
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch
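The import hunk above switches between the `lightning.pytorch` and `pytorch_lightning` package names; both expose the same `pl` API. Not part of this commit, but a common compatibility shim when either layout may be installed is:

    try:
        import pytorch_lightning as pl
    except ImportError:
        # newer installs ship the same API under the `lightning` package
        import lightning.pytorch as pl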