Adaptive Refinment and Multiple Optimizer callbacks
* Implementing a callback to switch between optimizers during training * Implementing the R3Refinment for collocation points * Modify trainer -> dataloader is created or updated by calling `_create_or_update_loader` * Adding `add_points` routine to AbstractProblem so that new points can be added without resampling from scratch
This commit is contained in:
committed by
Nicola Demo
parent
5a4c114d48
commit
4d1187898f
@@ -1,6 +1,7 @@
|
||||
""" Module for AbstractProblem class """
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..utils import merge_tensors, check_consistency
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
@@ -201,6 +202,36 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
|
||||
self._have_sampled_points[location] = True
|
||||
|
||||
def add_points(self, new_points):
|
||||
"""
|
||||
Adding points to the already sampled points
|
||||
|
||||
:param dict new_points: a dictionary with key the location to add the points
|
||||
and values the torch.Tensor points.
|
||||
"""
|
||||
|
||||
if sorted(new_points.keys()) != sorted(self.conditions):
|
||||
TypeError(f'Wrong locations for new points. Location ',
|
||||
f'should be in {self.conditions}.')
|
||||
|
||||
for location in new_points.keys():
|
||||
# extract old and new points
|
||||
old_pts = self.input_pts[location]
|
||||
new_pts = new_points[location]
|
||||
|
||||
# if they don't have the same variables error
|
||||
if sorted(old_pts.labels) != sorted(new_pts.labels):
|
||||
TypeError(f'Not matching variables for old and new points '
|
||||
f'in condition {location}.')
|
||||
if old_pts.labels != new_pts.labels:
|
||||
new_pts = torch.hstack([new_pts.extract([i]) for i in old_pts.labels])
|
||||
new_pts.labels = old_pts.labels
|
||||
|
||||
# merging
|
||||
merged_pts = torch.vstack([old_pts, new_points[location]])
|
||||
merged_pts.labels = old_pts.labels
|
||||
self.input_pts[location] = merged_pts
|
||||
|
||||
@property
|
||||
def have_sampled_points(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user