CPU/GPU/TPU training (#159)

* device training

---------

Co-authored-by: Dario Coscia <dcoscia@lovelace.maths.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
This commit is contained in:
Dario Coscia
2023-07-19 17:19:08 +02:00
committed by Nicola Demo
parent 38ecebd44b
commit 92e0e4920b
4 changed files with 62 additions and 28 deletions

View File

@@ -111,7 +111,7 @@ class AbstractProblem(metaclass=ABCMeta):
continue
self.input_pts[condition_name] = samples
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all', device=None):
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
"""
Generate a set of points to span the `Location` of all the conditions of
the problem.
@@ -129,9 +129,9 @@ class AbstractProblem(metaclass=ABCMeta):
:type locations: str, optional
:Example:
>>> pinn.span_pts(n=10, mode='grid')
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
>>> pinn.discretise_domain(n=10, mode='grid')
>>> pinn.discretise_domain(n=10, mode='grid', location=['bound1'])
>>> pinn.discretise_domain(n=10, mode='grid', variables=['x'])
.. warning::
``random`` is currently the only implemented ``mode`` for all geometries, i.e.
@@ -200,12 +200,6 @@ class AbstractProblem(metaclass=ABCMeta):
# the condition is sampled if input_pts contains all labels
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
self._have_sampled_points[location] = True
# setting device
if device:
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
# setting the grad
self.input_pts[location].requires_grad_(True)
self.input_pts[location].retain_grad()
@property
def have_sampled_points(self):