not bad this setup
This commit is contained in:
@@ -23,6 +23,8 @@ class GraphDataModule(LightningDataModule):
|
||||
build_radial_graph: bool = False,
|
||||
radius: float = None,
|
||||
unrolling_steps: int = 1,
|
||||
aggregate_timesteps: int = 1,
|
||||
min_normalized_diff: float = 1e-3,
|
||||
):
|
||||
super().__init__()
|
||||
self.hf_repo = hf_repo
|
||||
@@ -35,6 +37,9 @@ class GraphDataModule(LightningDataModule):
|
||||
None,
|
||||
)
|
||||
self.unrolling_steps = unrolling_steps
|
||||
self.aggregate_timesteps = aggregate_timesteps
|
||||
self.min_normalized_diff = min_normalized_diff
|
||||
|
||||
self.geometry_dict = {}
|
||||
self.train_size = train_size
|
||||
self.val_size = val_size
|
||||
@@ -109,10 +114,14 @@ class GraphDataModule(LightningDataModule):
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# print(temperatures.shape)
|
||||
|
||||
if not test:
|
||||
for t in range(1, temperatures.size(0)):
|
||||
diff = temperatures[t, :] - temperatures[t - 1, :]
|
||||
norm_diff = torch.norm(diff, p=2) / torch.norm(temperatures[t - 1], p=2)
|
||||
if norm_diff < self.min_normalized_diff:
|
||||
temperatures = temperatures[: t + 1, :]
|
||||
break
|
||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||
|
||||
if self.build_radial_graph:
|
||||
raise NotImplementedError(
|
||||
"Radial graph building not implemented yet."
|
||||
@@ -224,7 +233,7 @@ class GraphDataModule(LightningDataModule):
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
@@ -237,7 +246,7 @@ class GraphDataModule(LightningDataModule):
|
||||
batch_size=128,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
@@ -247,5 +256,5 @@ class GraphDataModule(LightningDataModule):
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user