add model and fix module and datamodule

This commit is contained in:
FilippoOlivo
2025-12-01 10:06:07 +01:00
parent 88bc5c05e4
commit c36c59d08d
4 changed files with 359 additions and 212 deletions

View File

@@ -85,7 +85,7 @@ class GraphDataModule(LightningDataModule):
conductivity = torch.tensor(
geometry["conductivity"], dtype=torch.float32
)
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:2]
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40]
times = torch.tensor(snapshot["times"], dtype=torch.float32)
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
@@ -131,9 +131,7 @@ class GraphDataModule(LightningDataModule):
data = []
for i in range(n_data):
x = temperatures[i, :].unsqueeze(-1)
print(x.shape)
y = temperatures[i + 1 : i + 1 + self.unrolling_steps, :].unsqueeze(-1).permute(1,0,2)
# print(y.shape)
data.append(MeshData(
x=x,
y=y,
@@ -187,9 +185,9 @@ class GraphDataModule(LightningDataModule):
def train_dataloader(self):
# ds = self.create_autoregressive_datasets(dataset="train")
# self.train_dataset = ds
print(type(self.train_data[0]))
# print(type(self.train_data[0]))
ds = [i for data in self.train_data for i in data]
print(type(ds[0]))
# print(type(ds[0]))
return DataLoader(
ds,
batch_size=self.batch_size,
@@ -202,7 +200,7 @@ class GraphDataModule(LightningDataModule):
ds = [i for data in self.val_data for i in data]
return DataLoader(
ds,
batch_size=self.batch_size,
batch_size=128,
shuffle=False,
num_workers=8,
pin_memory=True,