supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -5,6 +5,8 @@ from ..utils import merge_tensors, check_consistency
from copy import deepcopy
import torch
from .. import LabelTensor
class AbstractProblem(metaclass=ABCMeta):
"""
@@ -18,17 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
def __init__(self):
# variable storing all points
self.input_pts = {}
# varible to check if sampling is done. If no location
# element is presented in Condition this variable is set to true
self._have_sampled_points = {}
self._discretized_domains = {}
for name, domain in self.domains.items():
if isinstance(domain, (torch.Tensor, LabelTensor)):
self._discretized_domains[name] = domain
for condition_name in self.conditions:
self._have_sampled_points[condition_name] = False
self.conditions[condition_name]._problem = self
# # variable storing all points
# self.input_pts = {}
# put in self.input_pts all the points that we don't need to sample
self._span_condition_points()
# # varible to check if sampling is done. If no location
# # element is presented in Condition this variable is set to true
# self._have_sampled_points = {}
# for condition_name in self.conditions:
# self._have_sampled_points[condition_name] = False
# # put in self.input_pts all the points that we don't need to sample
# self._span_condition_points()
def __deepcopy__(self, memo):
"""
@@ -63,15 +74,20 @@ class AbstractProblem(metaclass=ABCMeta):
variables += self.spatial_variables
if hasattr(self, "temporal_variable"):
variables += self.temporal_variable
if hasattr(self, "parameters"):
if hasattr(self, "unknown_parameters"):
variables += self.parameters
if hasattr(self, "custom_variables"):
variables += self.custom_variables
return variables
@input_variables.setter
def input_variables(self, variables):
raise RuntimeError
@property
def domain(self):
@abstractmethod
def domains(self):
"""
The domain(s) where the conditions of the AbstractProblem are valid.
If more than one domain type is passed, a list of Location is
@@ -80,27 +96,7 @@ class AbstractProblem(metaclass=ABCMeta):
:return: the domain(s) of ``self``
:rtype: list[Location]
"""
domains = [
getattr(self, f"{t}_domain")
for t in ["spatial", "temporal", "parameter"]
if hasattr(self, f"{t}_domain")
]
if len(domains) == 1:
return domains[0]
elif len(domains) == 0:
raise RuntimeError
if len(set(map(type, domains))) == 1:
domain = domains[0].__class__({})
[domain.update(d) for d in domains]
return domain
else:
raise RuntimeError("different domains")
@input_variables.setter
def input_variables(self, variables):
raise RuntimeError
pass
@property
@abstractmethod
@@ -116,7 +112,9 @@ class AbstractProblem(metaclass=ABCMeta):
"""
The conditions of the problem.
"""
pass
return self._conditions
def _span_condition_points(self):
"""
@@ -281,28 +279,4 @@ class AbstractProblem(metaclass=ABCMeta):
# merging
merged_pts = torch.vstack([old_pts, new_pts])
merged_pts.labels = old_pts.labels
self.input_pts[location] = merged_pts
@property
def have_sampled_points(self):
"""
Check if all points for
``Location`` are sampled.
"""
return all(self._have_sampled_points.values())
@property
def not_sampled_points(self):
"""
Check which points are
not sampled.
"""
# variables which are not sampled
not_sampled = None
if self.have_sampled_points is False:
# check which one are not sampled:
not_sampled = []
for condition_name, is_sample in self._have_sampled_points.items():
if not is_sample:
not_sampled.append(condition_name)
return not_sampled
self.input_pts[location] = merged_pts