Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod
from ..utils import check_consistency
from ..domain import DomainInterface, CartesianDomain
from ..condition.domain_equation_condition import DomainEquationCondition
from ..condition import InputPointsEquationCondition
from copy import deepcopy
from .. import LabelTensor
from ..utils import merge_tensors
@@ -55,8 +54,8 @@ class AbstractProblem(metaclass=ABCMeta):
def input_pts(self):
to_return = {}
for cond_name, cond in self.conditions.items():
if hasattr(cond, "input_points"):
to_return[cond_name] = cond.input_points
if hasattr(cond, "input"):
to_return[cond_name] = cond.input
elif hasattr(cond, "domain"):
to_return[cond_name] = self._discretised_domains[cond.domain]
return to_return