fix callbacks
This commit is contained in:
@@ -5,7 +5,6 @@ from lightning.pytorch import Callback
|
|||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
from ..utils import check_consistency, is_function
|
from ..utils import check_consistency, is_function
|
||||||
from ..condition import InputTargetCondition
|
from ..condition import InputTargetCondition
|
||||||
from ..data.dataset import PinaGraphDataset
|
|
||||||
|
|
||||||
|
|
||||||
class NormalizerDataCallback(Callback):
|
class NormalizerDataCallback(Callback):
|
||||||
@@ -122,7 +121,10 @@ class NormalizerDataCallback(Callback):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Ensure datsets are not graph-based
|
# Ensure datsets are not graph-based
|
||||||
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
|
if any(
|
||||||
|
ds.is_graph_dataset
|
||||||
|
for ds in trainer.datamodule.train_dataset.values()
|
||||||
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"NormalizerDataCallback is not compatible with "
|
"NormalizerDataCallback is not compatible with "
|
||||||
"graph-based datasets."
|
"graph-based datasets."
|
||||||
@@ -164,8 +166,8 @@ class NormalizerDataCallback(Callback):
|
|||||||
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
|
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
|
||||||
"""
|
"""
|
||||||
for cond in conditions:
|
for cond in conditions:
|
||||||
if cond in dataset.conditions_dict:
|
if cond in dataset:
|
||||||
data = dataset.conditions_dict[cond][self.apply_to]
|
data = dataset[cond].data[self.apply_to]
|
||||||
shift = self.shift_fn(data)
|
shift = self.shift_fn(data)
|
||||||
scale = self.scale_fn(data)
|
scale = self.scale_fn(data)
|
||||||
self._normalizer[cond] = {
|
self._normalizer[cond] = {
|
||||||
@@ -197,25 +199,20 @@ class NormalizerDataCallback(Callback):
|
|||||||
|
|
||||||
:param PinaDataset dataset: The dataset to be normalized.
|
:param PinaDataset dataset: The dataset to be normalized.
|
||||||
"""
|
"""
|
||||||
# Initialize update dictionary
|
|
||||||
update_dataset_dict = {}
|
|
||||||
|
|
||||||
# Iterate over conditions and apply normalization
|
# Iterate over conditions and apply normalization
|
||||||
for cond, norm_params in self.normalizer.items():
|
for cond, norm_params in self.normalizer.items():
|
||||||
points = dataset.conditions_dict[cond][self.apply_to]
|
update_dataset_dict = {}
|
||||||
|
points = dataset[cond].data[self.apply_to]
|
||||||
scale = norm_params["scale"]
|
scale = norm_params["scale"]
|
||||||
shift = norm_params["shift"]
|
shift = norm_params["shift"]
|
||||||
normalized_points = self._norm_fn(points, scale, shift)
|
normalized_points = self._norm_fn(points, scale, shift)
|
||||||
update_dataset_dict[cond] = {
|
update_dataset_dict[self.apply_to] = (
|
||||||
self.apply_to: (
|
LabelTensor(normalized_points, points.labels)
|
||||||
LabelTensor(normalized_points, points.labels)
|
if isinstance(points, LabelTensor)
|
||||||
if isinstance(points, LabelTensor)
|
else normalized_points
|
||||||
else normalized_points
|
)
|
||||||
)
|
dataset[cond].data.update(update_dataset_dict)
|
||||||
}
|
|
||||||
|
|
||||||
# Update the dataset in-place
|
|
||||||
dataset.update_data(update_dataset_dict)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def normalizer(self):
|
def normalizer(self):
|
||||||
|
|||||||
@@ -133,13 +133,12 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
|
|||||||
|
|
||||||
:param PINNInterface solver: The solver object.
|
:param PINNInterface solver: The solver object.
|
||||||
"""
|
"""
|
||||||
new_points = {}
|
|
||||||
for name in self._condition_to_update:
|
for name in self._condition_to_update:
|
||||||
current_points = self.dataset.conditions_dict[name]["input"]
|
new_points = {}
|
||||||
new_points[name] = {
|
current_points = self.dataset[name].data["input"]
|
||||||
"input": self.sample(current_points, name, solver)
|
new_points["input"] = self.sample(current_points, name, solver)
|
||||||
}
|
|
||||||
self.dataset.update_data(new_points)
|
self.dataset[name].update_data(new_points)
|
||||||
|
|
||||||
def _compute_population_size(self, conditions):
|
def _compute_population_size(self, conditions):
|
||||||
"""
|
"""
|
||||||
@@ -150,6 +149,5 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
|
|||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
cond: len(self.dataset.conditions_dict[cond]["input"])
|
cond: len(self.dataset[cond].data["input"]) for cond in conditions
|
||||||
for cond in conditions
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user