not bad this setup

This commit is contained in:
2025-12-18 09:30:21 +01:00
parent 4fdf817d75
commit 0a034225ef
4 changed files with 120 additions and 58 deletions

View File

@@ -23,6 +23,8 @@ class GraphDataModule(LightningDataModule):
build_radial_graph: bool = False,
radius: float = None,
unrolling_steps: int = 1,
aggregate_timesteps: int = 1,
min_normalized_diff: float = 1e-3,
):
super().__init__()
self.hf_repo = hf_repo
@@ -35,6 +37,9 @@ class GraphDataModule(LightningDataModule):
None,
)
self.unrolling_steps = unrolling_steps
self.aggregate_timesteps = aggregate_timesteps
self.min_normalized_diff = min_normalized_diff
self.geometry_dict = {}
self.train_size = train_size
self.val_size = val_size
@@ -109,10 +114,14 @@ class GraphDataModule(LightningDataModule):
dim=0,
)
)
# print(temperatures.shape)
if not test:
for t in range(1, temperatures.size(0)):
diff = temperatures[t, :] - temperatures[t - 1, :]
norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
if norm_diff < self.min_normalized_diff:
temperatures = temperatures[: t + 1, :]
break
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
if self.build_radial_graph:
raise NotImplementedError(
"Radial graph building not implemented yet."
@@ -224,7 +233,7 @@ class GraphDataModule(LightningDataModule):
batch_size=self.batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
pin_memory=False,
)
def val_dataloader(self):
@@ -237,7 +246,7 @@ class GraphDataModule(LightningDataModule):
batch_size=128,
shuffle=False,
num_workers=8,
pin_memory=True,
pin_memory=False,
)
def test_dataloader(self):
@@ -247,5 +256,5 @@ class GraphDataModule(LightningDataModule):
batch_size=1,
shuffle=False,
num_workers=8,
pin_memory=True,
pin_memory=False,
)