Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -1,3 +1,4 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors
@@ -66,9 +67,12 @@ class Collector:
for loc in sample_locations:
# get condition
condition = self.problem.conditions[loc]
condition_domain = condition.domain
if isinstance(condition_domain, str):
condition_domain = self.problem.domains[condition_domain]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
@@ -84,10 +88,11 @@ class Collector:
# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
condition_domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
@@ -110,5 +115,6 @@ class Collector:
if not self._is_conditions_ready[k]:
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
self.data_collections[k]['input_points'] = LabelTensor.vstack(
[self.data_collections[k][
'input_points'], v])