Documentation for v0.1 version (#199)
* Adding Equations, solving typos * improve _code.rst * the team rst and restuctore index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
3f9305d475
commit
8b7b61b3bd
@@ -27,7 +27,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
# put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
|
||||
|
||||
@property
|
||||
def input_variables(self):
|
||||
"""
|
||||
@@ -55,10 +55,11 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
def domain(self):
|
||||
"""
|
||||
The domain(s) where the conditions of the AbstractProblem are valid.
|
||||
If more than one domain type is passed, a list of Location is
|
||||
retured.
|
||||
|
||||
:return: the domain(s) of self
|
||||
:rtype: list (if more than one domain are defined),
|
||||
`Span` domain (of only one domain is defined)
|
||||
:return: the domain(s) of ``self``
|
||||
:rtype: list[Location]
|
||||
"""
|
||||
domains = [
|
||||
getattr(self, f'{t}_domain')
|
||||
@@ -109,7 +110,11 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
self.input_pts[condition_name] = samples
|
||||
self._have_sampled_points[condition_name] = True
|
||||
|
||||
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
|
||||
def discretise_domain(self,
|
||||
n,
|
||||
mode='random',
|
||||
variables='all',
|
||||
locations='all'):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
@@ -122,9 +127,9 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
latin hypercube sampling, ``latin`` or ``lh``;
|
||||
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
|
||||
:param variables: problem's variables to be sampled, defaults to 'all'.
|
||||
:type variables: str or list[str], optional
|
||||
:type variables: str | list[str]
|
||||
:param locations: problem's locations from where to sample, defaults to 'all'.
|
||||
:type locations: str, optional
|
||||
:type locations: str
|
||||
|
||||
:Example:
|
||||
>>> pinn.discretise_domain(n=10, mode='grid')
|
||||
@@ -146,24 +151,24 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
check_consistency(mode, str)
|
||||
if mode not in ['random', 'grid', 'lh', 'chebyshev', 'latin']:
|
||||
raise TypeError(f'mode {mode} not valid.')
|
||||
|
||||
|
||||
# check consistency variables
|
||||
if variables == 'all':
|
||||
variables = self.input_variables
|
||||
else:
|
||||
check_consistency(variables, str)
|
||||
|
||||
if sorted(variables) != sorted(self.input_variables):
|
||||
|
||||
if sorted(variables) != sorted(self.input_variables):
|
||||
TypeError(f'Wrong variables for sampling. Variables ',
|
||||
f'should be in {self.input_variables}.')
|
||||
|
||||
|
||||
# check consistency location
|
||||
if locations == 'all':
|
||||
locations = [condition for condition in self.conditions]
|
||||
else:
|
||||
check_consistency(locations, str)
|
||||
|
||||
if sorted(locations) != sorted(self.conditions):
|
||||
if sorted(locations) != sorted(self.conditions):
|
||||
TypeError(f'Wrong locations for sampling. Location ',
|
||||
f'should be in {self.conditions}.')
|
||||
|
||||
@@ -174,7 +179,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# we try to check if we have already sampled
|
||||
try:
|
||||
already_sampled = [self.input_pts[location]]
|
||||
# if we have not sampled, a key error is thrown
|
||||
# if we have not sampled, a key error is thrown
|
||||
except KeyError:
|
||||
already_sampled = []
|
||||
|
||||
@@ -187,16 +192,15 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
self._have_sampled_points[location] = False
|
||||
|
||||
# build samples
|
||||
samples = [condition.location.sample(
|
||||
n=n,
|
||||
mode=mode,
|
||||
variables=variables)
|
||||
] + already_sampled
|
||||
samples = [
|
||||
condition.location.sample(n=n, mode=mode, variables=variables)
|
||||
] + already_sampled
|
||||
pts = merge_tensors(samples)
|
||||
self.input_pts[location] = pts
|
||||
|
||||
# the condition is sampled if input_pts contains all labels
|
||||
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
|
||||
if sorted(self.input_pts[location].labels) == sorted(
|
||||
self.input_variables):
|
||||
self._have_sampled_points[location] = True
|
||||
|
||||
def add_points(self, new_points):
|
||||
@@ -207,21 +211,22 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
and values the torch.Tensor points.
|
||||
"""
|
||||
|
||||
if sorted(new_points.keys()) != sorted(self.conditions):
|
||||
if sorted(new_points.keys()) != sorted(self.conditions):
|
||||
TypeError(f'Wrong locations for new points. Location ',
|
||||
f'should be in {self.conditions}.')
|
||||
|
||||
|
||||
for location in new_points.keys():
|
||||
# extract old and new points
|
||||
old_pts = self.input_pts[location]
|
||||
new_pts = new_points[location]
|
||||
|
||||
# if they don't have the same variables error
|
||||
if sorted(old_pts.labels) != sorted(new_pts.labels):
|
||||
if sorted(old_pts.labels) != sorted(new_pts.labels):
|
||||
TypeError(f'Not matching variables for old and new points '
|
||||
f'in condition {location}.')
|
||||
if old_pts.labels != new_pts.labels:
|
||||
new_pts = torch.hstack([new_pts.extract([i]) for i in old_pts.labels])
|
||||
new_pts = torch.hstack(
|
||||
[new_pts.extract([i]) for i in old_pts.labels])
|
||||
new_pts.labels = old_pts.labels
|
||||
|
||||
# merging
|
||||
@@ -233,13 +238,14 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
def have_sampled_points(self):
|
||||
"""
|
||||
Check if all points for
|
||||
``'Location'`` are sampled.
|
||||
"""
|
||||
``Location`` are sampled.
|
||||
"""
|
||||
return all(self._have_sampled_points.values())
|
||||
|
||||
|
||||
@property
|
||||
def not_sampled_points(self):
|
||||
"""Check which points are
|
||||
"""
|
||||
Check which points are
|
||||
not sampled.
|
||||
"""
|
||||
# variables which are not sampled
|
||||
@@ -251,4 +257,3 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
if not is_sample:
|
||||
not_sampled.append(condition_name)
|
||||
return not_sampled
|
||||
|
||||
|
||||
Reference in New Issue
Block a user