edited utils to take list (#115)
* enhanced difference domain * refactored utils * fixed typo * added tests --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
10
pina/pinn.py
10
pina/pinn.py
@@ -48,11 +48,11 @@ class PINN(SolverInterface):
|
||||
super().__init__(model=model, problem=problem, extra_features=extra_features)
|
||||
|
||||
# check consistency
|
||||
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
|
||||
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
|
||||
check_consistency(scheduler, LRScheduler, 'scheduler', subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
|
||||
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
|
||||
check_consistency(optimizer, torch.optim.Optimizer, subclass=True)
|
||||
check_consistency(optimizer_kwargs, dict)
|
||||
check_consistency(scheduler, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict)
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user