Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""Module for the AbstractProblem class."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface, CartesianDomain
|
||||
from ..condition.domain_equation_condition import DomainEquationCondition
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import merge_tensors
|
||||
from ..utils import merge_tensors, custom_warning_format
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
@@ -23,14 +24,11 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
Initialization of the :class:`AbstractProblem` class.
|
||||
"""
|
||||
self._discretised_domains = {}
|
||||
# create collector to manage problem data
|
||||
|
||||
# create hook conditions <-> problems
|
||||
for condition_name in self.conditions:
|
||||
self.conditions[condition_name].problem = self
|
||||
|
||||
self._batching_dimension = 0
|
||||
|
||||
# Store in domains dict all the domains object directly passed to
|
||||
# ConditionInterface. Done for back compatibility with PINA <0.2
|
||||
if not hasattr(self, "domains"):
|
||||
@@ -41,41 +39,57 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
self.domains[cond_name] = cond.domain
|
||||
cond.domain = cond_name
|
||||
|
||||
self._collected_data = {}
|
||||
|
||||
@property
|
||||
def batching_dimension(self):
|
||||
def collected_data(self):
|
||||
"""
|
||||
Get batching dimension.
|
||||
Return the collected data from the problem's conditions. If some domains
|
||||
are not sampled, they will not be returned by collected data.
|
||||
|
||||
:return: The batching dimension.
|
||||
:rtype: int
|
||||
:return: The collected data. Keys are condition names, and values are
|
||||
dictionaries containing the input points and the corresponding
|
||||
equations or target points.
|
||||
:rtype: dict
|
||||
"""
|
||||
return self._batching_dimension
|
||||
|
||||
@batching_dimension.setter
|
||||
def batching_dimension(self, value):
|
||||
"""
|
||||
Set the batching dimension.
|
||||
|
||||
:param int value: The batching dimension.
|
||||
"""
|
||||
self._batching_dimension = value
|
||||
# collect data so far
|
||||
self.collect_data()
|
||||
# raise warning if some sample data are missing
|
||||
if not self.are_all_domains_discretised:
|
||||
warnings.formatwarning = custom_warning_format
|
||||
warnings.filterwarnings("always", category=RuntimeWarning)
|
||||
warning_message = "\n".join(
|
||||
[
|
||||
f"""{" " * 13} ---> Domain {key} {
|
||||
"sampled" if key in self.discretised_domains
|
||||
else
|
||||
"not sampled"}"""
|
||||
for key in self.domains
|
||||
]
|
||||
)
|
||||
warnings.warn(
|
||||
"Some of the domains are still not sampled. Consider calling "
|
||||
"problem.discretise_domain function for all domains before "
|
||||
"accessing the collected data:\n"
|
||||
f"{warning_message}",
|
||||
RuntimeWarning,
|
||||
)
|
||||
return self._collected_data
|
||||
|
||||
# back compatibility 0.1
|
||||
@property
|
||||
def input_pts(self):
|
||||
"""
|
||||
Return a dictionary mapping condition names to their corresponding
|
||||
input points.
|
||||
input points. If some domains are not sampled, they will not be returned
|
||||
and the corresponding condition will be empty.
|
||||
|
||||
:return: The input points of the problem.
|
||||
:rtype: dict
|
||||
"""
|
||||
to_return = {}
|
||||
for cond_name, cond in self.conditions.items():
|
||||
if hasattr(cond, "input"):
|
||||
to_return[cond_name] = cond.input
|
||||
elif hasattr(cond, "domain"):
|
||||
to_return[cond_name] = self._discretised_domains[cond.domain]
|
||||
for cond_name, data in self.collected_data.items():
|
||||
to_return[cond_name] = data["input"]
|
||||
return to_return
|
||||
|
||||
@property
|
||||
@@ -300,3 +314,29 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
self.discretised_domains[k] = LabelTensor.vstack(
|
||||
[self.discretised_domains[k], v]
|
||||
)
|
||||
|
||||
def collect_data(self):
|
||||
"""
|
||||
Aggregate data from the problem's conditions into a single dictionary.
|
||||
"""
|
||||
data = {}
|
||||
# Iterate over the conditions and collect data
|
||||
for condition_name in self.conditions:
|
||||
condition = self.conditions[condition_name]
|
||||
# Check if the condition has an domain attribute
|
||||
if hasattr(condition, "domain"):
|
||||
# Only store the discretisation points if the domain is
|
||||
# in the dictionary
|
||||
if condition.domain in self.discretised_domains:
|
||||
samples = self.discretised_domains[condition.domain]
|
||||
data[condition_name] = {
|
||||
"input": samples,
|
||||
"equation": condition.equation,
|
||||
}
|
||||
else:
|
||||
# If the condition does not have a domain attribute, store
|
||||
# the input and target points
|
||||
keys = condition.__slots__
|
||||
values = [getattr(condition, name) for name in keys]
|
||||
data[condition_name] = dict(zip(keys, values))
|
||||
self._collected_data = data
|
||||
|
||||
Reference in New Issue
Block a user