add model and fix module and datamodule
This commit is contained in:
@@ -85,7 +85,7 @@ class GraphDataModule(LightningDataModule):
|
||||
conductivity = torch.tensor(
|
||||
geometry["conductivity"], dtype=torch.float32
|
||||
)
|
||||
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:2]
|
||||
temperatures = torch.tensor(snapshot["temperatures"], dtype=torch.float32)[:40]
|
||||
times = torch.tensor(snapshot["times"], dtype=torch.float32)
|
||||
|
||||
pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
|
||||
@@ -131,9 +131,7 @@ class GraphDataModule(LightningDataModule):
|
||||
data = []
|
||||
for i in range(n_data):
|
||||
x = temperatures[i, :].unsqueeze(-1)
|
||||
print(x.shape)
|
||||
y = temperatures[i + 1 : i + 1 + self.unrolling_steps, :].unsqueeze(-1).permute(1,0,2)
|
||||
# print(y.shape)
|
||||
data.append(MeshData(
|
||||
x=x,
|
||||
y=y,
|
||||
@@ -187,9 +185,9 @@ class GraphDataModule(LightningDataModule):
|
||||
def train_dataloader(self):
|
||||
# ds = self.create_autoregressive_datasets(dataset="train")
|
||||
# self.train_dataset = ds
|
||||
print(type(self.train_data[0]))
|
||||
# print(type(self.train_data[0]))
|
||||
ds = [i for data in self.train_data for i in data]
|
||||
print(type(ds[0]))
|
||||
# print(type(ds[0]))
|
||||
return DataLoader(
|
||||
ds,
|
||||
batch_size=self.batch_size,
|
||||
@@ -202,7 +200,7 @@ class GraphDataModule(LightningDataModule):
|
||||
ds = [i for data in self.val_data for i in data]
|
||||
return DataLoader(
|
||||
ds,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=128,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
|
||||
Reference in New Issue
Block a user