edited utils to take list (#115)

* enhanced difference domain
* refactored utils
* fixed typo
* added tests
---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Kush
2023-06-19 18:47:52 +02:00
committed by Nicola Demo
parent aaf2bed732
commit 62ec69ccac
9 changed files with 73 additions and 47 deletions

View File

@@ -48,11 +48,11 @@ class PINN(SolverInterface):
super().__init__(model=model, problem=problem, extra_features=extra_features)
# check consistency
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True)
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
check_consistency(optimizer, torch.optim.Optimizer, subclass=True)
check_consistency(optimizer_kwargs, dict)
check_consistency(scheduler, LRScheduler, subclass=True)
check_consistency(scheduler_kwargs, dict)
check_consistency(loss, (LossInterface, _Loss), subclass=False)
# assign variables
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)