Adaptive Refinment and Multiple Optimizer callbacks

* Implementing a callback to switch between optimizers during training
* Implementing the R3Refinment for collocation points
* Modify trainer -> dataloader is created or updated by calling `_create_or_update_loader`
* Adding `add_points` routine to AbstractProblem so that new points can be added without resampling from scratch
This commit is contained in:
Dario Coscia
2023-09-14 18:37:02 +02:00
committed by Nicola Demo
parent 5a4c114d48
commit 4d1187898f
3 changed files with 229 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
""" Module for AbstractProblem class """
from abc import ABCMeta, abstractmethod
from ..utils import merge_tensors, check_consistency
import torch
class AbstractProblem(metaclass=ABCMeta):
@@ -201,6 +202,36 @@ class AbstractProblem(metaclass=ABCMeta):
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
self._have_sampled_points[location] = True
def add_points(self, new_points):
"""
Adding points to the already sampled points
:param dict new_points: a dictionary with key the location to add the points
and values the torch.Tensor points.
"""
if sorted(new_points.keys()) != sorted(self.conditions):
TypeError(f'Wrong locations for new points. Location ',
f'should be in {self.conditions}.')
for location in new_points.keys():
# extract old and new points
old_pts = self.input_pts[location]
new_pts = new_points[location]
# if they don't have the same variables error
if sorted(old_pts.labels) != sorted(new_pts.labels):
TypeError(f'Not matching variables for old and new points '
f'in condition {location}.')
if old_pts.labels != new_pts.labels:
new_pts = torch.hstack([new_pts.extract([i]) for i in old_pts.labels])
new_pts.labels = old_pts.labels
# merging
merged_pts = torch.vstack([old_pts, new_points[location]])
merged_pts.labels = old_pts.labels
self.input_pts[location] = merged_pts
@property
def have_sampled_points(self):
"""