🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
""" Module for AbstractProblem class """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..utils import merge_tensors, check_consistency
|
||||
import torch
|
||||
@@ -40,13 +41,13 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
variables = []
|
||||
|
||||
if hasattr(self, 'spatial_variables'):
|
||||
if hasattr(self, "spatial_variables"):
|
||||
variables += self.spatial_variables
|
||||
if hasattr(self, 'temporal_variable'):
|
||||
if hasattr(self, "temporal_variable"):
|
||||
variables += self.temporal_variable
|
||||
if hasattr(self, 'parameters'):
|
||||
if hasattr(self, "parameters"):
|
||||
variables += self.parameters
|
||||
if hasattr(self, 'custom_variables'):
|
||||
if hasattr(self, "custom_variables"):
|
||||
variables += self.custom_variables
|
||||
|
||||
return variables
|
||||
@@ -62,9 +63,9 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
:rtype: list[Location]
|
||||
"""
|
||||
domains = [
|
||||
getattr(self, f'{t}_domain')
|
||||
for t in ['spatial', 'temporal', 'parameter']
|
||||
if hasattr(self, f'{t}_domain')
|
||||
getattr(self, f"{t}_domain")
|
||||
for t in ["spatial", "temporal", "parameter"]
|
||||
if hasattr(self, f"{t}_domain")
|
||||
]
|
||||
|
||||
if len(domains) == 1:
|
||||
@@ -77,7 +78,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
[domain.update(d) for d in domains]
|
||||
return domain
|
||||
else:
|
||||
raise RuntimeError('different domains')
|
||||
raise RuntimeError("different domains")
|
||||
|
||||
@input_variables.setter
|
||||
def input_variables(self, variables):
|
||||
@@ -105,24 +106,27 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
for condition_name in self.conditions:
|
||||
condition = self.conditions[condition_name]
|
||||
if hasattr(condition, 'input_points'):
|
||||
if hasattr(condition, "input_points"):
|
||||
samples = condition.input_points
|
||||
self.input_pts[condition_name] = samples
|
||||
self._have_sampled_points[condition_name] = True
|
||||
if hasattr(self, 'unknown_parameter_domain'):
|
||||
if hasattr(self, "unknown_parameter_domain"):
|
||||
# initialize the unknown parameters of the inverse problem given
|
||||
# the domain the user gives
|
||||
self.unknown_parameters = {}
|
||||
for i, var in enumerate(self.unknown_variables):
|
||||
range_var = self.unknown_parameter_domain.range_[var]
|
||||
tensor_var = torch.rand(1, requires_grad=True) * range_var[1] + range_var[0]
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(tensor_var)
|
||||
tensor_var = (
|
||||
torch.rand(1, requires_grad=True) * range_var[1]
|
||||
+ range_var[0]
|
||||
)
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(
|
||||
tensor_var
|
||||
)
|
||||
|
||||
def discretise_domain(self,
|
||||
n,
|
||||
mode='random',
|
||||
variables='all',
|
||||
locations='all'):
|
||||
def discretise_domain(
|
||||
self, n, mode="random", variables="all", locations="all"
|
||||
):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
@@ -157,28 +161,32 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
# check consistency mode
|
||||
check_consistency(mode, str)
|
||||
if mode not in ['random', 'grid', 'lh', 'chebyshev', 'latin']:
|
||||
raise TypeError(f'mode {mode} not valid.')
|
||||
if mode not in ["random", "grid", "lh", "chebyshev", "latin"]:
|
||||
raise TypeError(f"mode {mode} not valid.")
|
||||
|
||||
# check consistency variables
|
||||
if variables == 'all':
|
||||
if variables == "all":
|
||||
variables = self.input_variables
|
||||
else:
|
||||
check_consistency(variables, str)
|
||||
|
||||
if sorted(variables) != sorted(self.input_variables):
|
||||
TypeError(f'Wrong variables for sampling. Variables ',
|
||||
f'should be in {self.input_variables}.')
|
||||
TypeError(
|
||||
f"Wrong variables for sampling. Variables ",
|
||||
f"should be in {self.input_variables}.",
|
||||
)
|
||||
|
||||
# check consistency location
|
||||
if locations == 'all':
|
||||
if locations == "all":
|
||||
locations = [condition for condition in self.conditions]
|
||||
else:
|
||||
check_consistency(locations, str)
|
||||
|
||||
if sorted(locations) != sorted(self.conditions):
|
||||
TypeError(f'Wrong locations for sampling. Location ',
|
||||
f'should be in {self.conditions}.')
|
||||
TypeError(
|
||||
f"Wrong locations for sampling. Location ",
|
||||
f"should be in {self.conditions}.",
|
||||
)
|
||||
|
||||
# sampling
|
||||
for location in locations:
|
||||
@@ -208,10 +216,10 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
# the condition is sampled if input_pts contains all labels
|
||||
if sorted(self.input_pts[location].labels) == sorted(
|
||||
self.input_variables):
|
||||
self.input_variables
|
||||
):
|
||||
self._have_sampled_points[location] = True
|
||||
|
||||
|
||||
def add_points(self, new_points):
|
||||
"""
|
||||
Adding points to the already sampled points.
|
||||
@@ -221,8 +229,10 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
|
||||
if sorted(new_points.keys()) != sorted(self.conditions):
|
||||
TypeError(f'Wrong locations for new points. Location ',
|
||||
f'should be in {self.conditions}.')
|
||||
TypeError(
|
||||
f"Wrong locations for new points. Location ",
|
||||
f"should be in {self.conditions}.",
|
||||
)
|
||||
|
||||
for location in new_points.keys():
|
||||
# extract old and new points
|
||||
@@ -231,11 +241,14 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
# if they don't have the same variables error
|
||||
if sorted(old_pts.labels) != sorted(new_pts.labels):
|
||||
TypeError(f'Not matching variables for old and new points '
|
||||
f'in condition {location}.')
|
||||
TypeError(
|
||||
f"Not matching variables for old and new points "
|
||||
f"in condition {location}."
|
||||
)
|
||||
if old_pts.labels != new_pts.labels:
|
||||
new_pts = torch.hstack(
|
||||
[new_pts.extract([i]) for i in old_pts.labels])
|
||||
[new_pts.extract([i]) for i in old_pts.labels]
|
||||
)
|
||||
new_pts.labels = old_pts.labels
|
||||
|
||||
# merging
|
||||
@@ -266,4 +279,3 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
if not is_sample:
|
||||
not_sampled.append(condition_name)
|
||||
return not_sampled
|
||||
|
||||
|
||||
Reference in New Issue
Block a user