Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,4 @@
|
||||
""" Module for Physics Informed Neural Network. """
|
||||
"""Module for Physics Informed Neural Network."""
|
||||
|
||||
import torch
|
||||
|
||||
@@ -48,13 +48,15 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
@@ -67,12 +69,14 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
:param torch.nn.Module loss: The loss function to be minimized;
|
||||
default `None`.
|
||||
"""
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
@@ -112,7 +116,4 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
}
|
||||
)
|
||||
self.scheduler.hook(self.optimizer)
|
||||
return (
|
||||
[self.optimizer.instance],
|
||||
[self.scheduler.instance]
|
||||
)
|
||||
return ([self.optimizer.instance], [self.scheduler.instance])
|
||||
|
||||
Reference in New Issue
Block a user