Dev Update (#582)

* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
Dario Coscia
2025-06-13 17:34:37 +02:00
committed by GitHub
parent 6b355b45de
commit 7bf7d34d0f
40 changed files with 1963 additions and 581 deletions

View File

@@ -1,12 +1,13 @@
"""Module for the AbstractProblem class."""
from abc import ABCMeta, abstractmethod
import warnings
from copy import deepcopy
from ..utils import check_consistency
from ..domain import DomainInterface, CartesianDomain
from ..condition.domain_equation_condition import DomainEquationCondition
from ..label_tensor import LabelTensor
from ..utils import merge_tensors
from ..utils import merge_tensors, custom_warning_format
class AbstractProblem(metaclass=ABCMeta):
@@ -23,14 +24,11 @@ class AbstractProblem(metaclass=ABCMeta):
Initialization of the :class:`AbstractProblem` class.
"""
self._discretised_domains = {}
# create collector to manage problem data
# create hook conditions <-> problems
for condition_name in self.conditions:
self.conditions[condition_name].problem = self
self._batching_dimension = 0
# Store in domains dict all the domains object directly passed to
# ConditionInterface. Done for back compatibility with PINA <0.2
if not hasattr(self, "domains"):
@@ -41,41 +39,57 @@ class AbstractProblem(metaclass=ABCMeta):
self.domains[cond_name] = cond.domain
cond.domain = cond_name
self._collected_data = {}
@property
def batching_dimension(self):
def collected_data(self):
"""
Get batching dimension.
Return the collected data from the problem's conditions. If some domains
are not sampled, they will not be returned by collected data.
:return: The batching dimension.
:rtype: int
:return: The collected data. Keys are condition names, and values are
dictionaries containing the input points and the corresponding
equations or target points.
:rtype: dict
"""
return self._batching_dimension
@batching_dimension.setter
def batching_dimension(self, value):
"""
Set the batching dimension.
:param int value: The batching dimension.
"""
self._batching_dimension = value
# collect data so far
self.collect_data()
# raise warning if some sample data are missing
if not self.are_all_domains_discretised:
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=RuntimeWarning)
warning_message = "\n".join(
[
f"""{" " * 13} ---> Domain {key} {
"sampled" if key in self.discretised_domains
else
"not sampled"}"""
for key in self.domains
]
)
warnings.warn(
"Some of the domains are still not sampled. Consider calling "
"problem.discretise_domain function for all domains before "
"accessing the collected data:\n"
f"{warning_message}",
RuntimeWarning,
)
return self._collected_data
# back compatibility 0.1
@property
def input_pts(self):
"""
Return a dictionary mapping condition names to their corresponding
input points.
input points. If some domains are not sampled, they will not be returned
and the corresponding condition will be empty.
:return: The input points of the problem.
:rtype: dict
"""
to_return = {}
for cond_name, cond in self.conditions.items():
if hasattr(cond, "input"):
to_return[cond_name] = cond.input
elif hasattr(cond, "domain"):
to_return[cond_name] = self._discretised_domains[cond.domain]
for cond_name, data in self.collected_data.items():
to_return[cond_name] = data["input"]
return to_return
@property
@@ -300,3 +314,29 @@ class AbstractProblem(metaclass=ABCMeta):
self.discretised_domains[k] = LabelTensor.vstack(
[self.discretised_domains[k], v]
)
def collect_data(self):
"""
Aggregate data from the problem's conditions into a single dictionary.
"""
data = {}
# Iterate over the conditions and collect data
for condition_name in self.conditions:
condition = self.conditions[condition_name]
# Check if the condition has an domain attribute
if hasattr(condition, "domain"):
# Only store the discretisation points if the domain is
# in the dictionary
if condition.domain in self.discretised_domains:
samples = self.discretised_domains[condition.domain]
data[condition_name] = {
"input": samples,
"equation": condition.equation,
}
else:
# If the condition does not have a domain attribute, store
# the input and target points
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
data[condition_name] = dict(zip(keys, values))
self._collected_data = data