supervised working
This commit is contained in:
@@ -5,6 +5,8 @@ from ..utils import merge_tensors, check_consistency
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
from .. import LabelTensor
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
@@ -18,17 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
# variable storing all points
|
||||
self.input_pts = {}
|
||||
|
||||
# varible to check if sampling is done. If no location
|
||||
# element is presented in Condition this variable is set to true
|
||||
self._have_sampled_points = {}
|
||||
self._discretized_domains = {}
|
||||
|
||||
for name, domain in self.domains.items():
|
||||
if isinstance(domain, (torch.Tensor, LabelTensor)):
|
||||
self._discretized_domains[name] = domain
|
||||
|
||||
for condition_name in self.conditions:
|
||||
self._have_sampled_points[condition_name] = False
|
||||
self.conditions[condition_name]._problem = self
|
||||
# # variable storing all points
|
||||
# self.input_pts = {}
|
||||
|
||||
# put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
# # varible to check if sampling is done. If no location
|
||||
# # element is presented in Condition this variable is set to true
|
||||
# self._have_sampled_points = {}
|
||||
# for condition_name in self.conditions:
|
||||
# self._have_sampled_points[condition_name] = False
|
||||
|
||||
# # put in self.input_pts all the points that we don't need to sample
|
||||
# self._span_condition_points()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""
|
||||
@@ -63,15 +74,20 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
variables += self.spatial_variables
|
||||
if hasattr(self, "temporal_variable"):
|
||||
variables += self.temporal_variable
|
||||
if hasattr(self, "parameters"):
|
||||
if hasattr(self, "unknown_parameters"):
|
||||
variables += self.parameters
|
||||
if hasattr(self, "custom_variables"):
|
||||
variables += self.custom_variables
|
||||
|
||||
return variables
|
||||
|
||||
@input_variables.setter
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
@abstractmethod
|
||||
def domains(self):
|
||||
"""
|
||||
The domain(s) where the conditions of the AbstractProblem are valid.
|
||||
If more than one domain type is passed, a list of Location is
|
||||
@@ -80,27 +96,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
:return: the domain(s) of ``self``
|
||||
:rtype: list[Location]
|
||||
"""
|
||||
domains = [
|
||||
getattr(self, f"{t}_domain")
|
||||
for t in ["spatial", "temporal", "parameter"]
|
||||
if hasattr(self, f"{t}_domain")
|
||||
]
|
||||
|
||||
if len(domains) == 1:
|
||||
return domains[0]
|
||||
elif len(domains) == 0:
|
||||
raise RuntimeError
|
||||
|
||||
if len(set(map(type, domains))) == 1:
|
||||
domain = domains[0].__class__({})
|
||||
[domain.update(d) for d in domains]
|
||||
return domain
|
||||
else:
|
||||
raise RuntimeError("different domains")
|
||||
|
||||
@input_variables.setter
|
||||
def input_variables(self, variables):
|
||||
raise RuntimeError
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -116,7 +112,9 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
"""
|
||||
The conditions of the problem.
|
||||
"""
|
||||
pass
|
||||
return self._conditions
|
||||
|
||||
|
||||
|
||||
def _span_condition_points(self):
|
||||
"""
|
||||
@@ -281,28 +279,4 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
# merging
|
||||
merged_pts = torch.vstack([old_pts, new_pts])
|
||||
merged_pts.labels = old_pts.labels
|
||||
self.input_pts[location] = merged_pts
|
||||
|
||||
@property
|
||||
def have_sampled_points(self):
|
||||
"""
|
||||
Check if all points for
|
||||
``Location`` are sampled.
|
||||
"""
|
||||
return all(self._have_sampled_points.values())
|
||||
|
||||
@property
|
||||
def not_sampled_points(self):
|
||||
"""
|
||||
Check which points are
|
||||
not sampled.
|
||||
"""
|
||||
# variables which are not sampled
|
||||
not_sampled = None
|
||||
if self.have_sampled_points is False:
|
||||
# check which one are not sampled:
|
||||
not_sampled = []
|
||||
for condition_name, is_sample in self._have_sampled_points.items():
|
||||
if not is_sample:
|
||||
not_sampled.append(condition_name)
|
||||
return not_sampled
|
||||
self.input_pts[location] = merged_pts
|
||||
Reference in New Issue
Block a user