update model, model and datamodule
This commit is contained in:
@@ -35,14 +35,6 @@ class GraphDataModule(LightningDataModule):
|
|||||||
def prepare_data(self):
|
def prepare_data(self):
|
||||||
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
|
dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
|
||||||
geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name]
|
geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name]
|
||||||
# data = [
|
|
||||||
# self._build_dataset(snapshot, geometry)
|
|
||||||
# for snapshot, geometry in tqdm(
|
|
||||||
# zip(hf_dataset, self.geometry),
|
|
||||||
# desc="Building graphs",
|
|
||||||
# total=len(hf_dataset),
|
|
||||||
# )
|
|
||||||
# ]
|
|
||||||
|
|
||||||
total_len = len(dataset)
|
total_len = len(dataset)
|
||||||
train_len = int(self.train_size * total_len)
|
train_len = int(self.train_size * total_len)
|
||||||
@@ -127,7 +119,7 @@ class GraphDataModule(LightningDataModule):
|
|||||||
pos=pos,
|
pos=pos,
|
||||||
edge_attr=edge_attr,
|
edge_attr=edge_attr,
|
||||||
y=temperature.unsqueeze(-1),
|
y=temperature.unsqueeze(-1),
|
||||||
boundary_mask=torch.tensor(0), # Fake value (to fix)
|
boundary_mask=boundary_mask,
|
||||||
boundary_values=torch.tensor(0), # Fake value (to fix)
|
boundary_values=torch.tensor(0), # Fake value (to fix)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,8 +135,6 @@ class GraphDataModule(LightningDataModule):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def setup(self, stage: str = None):
|
def setup(self, stage: str = None):
|
||||||
print(type(self.dataset_dict["train"]))
|
|
||||||
|
|
||||||
if stage == "fit" or stage is None:
|
if stage == "fit" or stage is None:
|
||||||
self.train_data = [
|
self.train_data = [
|
||||||
self._build_dataset(snap, geom)
|
self._build_dataset(snap, geom)
|
||||||
|
|||||||
@@ -77,13 +77,6 @@ class ConditionalGNOBlock(MessagePassing):
|
|||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.x_net = nn.Sequential(
|
|
||||||
nn.Linear(hidden_ch, hidden_ch * 2),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(hidden_ch * 2, hidden_ch),
|
|
||||||
nn.GELU(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.c_ij_net = nn.Sequential(
|
self.c_ij_net = nn.Sequential(
|
||||||
nn.Linear(hidden_ch, hidden_ch // 2),
|
nn.Linear(hidden_ch, hidden_ch // 2),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
@@ -116,9 +109,6 @@ class ConditionalGNOBlock(MessagePassing):
|
|||||||
c_ij = 0.5 * (c_i + c_j)
|
c_ij = 0.5 * (c_i + c_j)
|
||||||
gamma = self.gamma_net(torch.cat([x_i, x_j], dim=-1))
|
gamma = self.gamma_net(torch.cat([x_i, x_j], dim=-1))
|
||||||
gate = self.edge_attr_net(edge_attr)
|
gate = self.edge_attr_net(edge_attr)
|
||||||
m = (
|
|
||||||
gamma * self.diff_net(x_j - x_i) + (1 - gamma) * self.x_net(x_j)
|
|
||||||
) * gate
|
|
||||||
m = self.diff_net(x_j - x_i) * gate
|
m = self.diff_net(x_j - x_i) * gate
|
||||||
m = m * self.c_ij_net(c_ij)
|
m = m * self.c_ij_net(c_ij)
|
||||||
return m
|
return m
|
||||||
@@ -158,17 +148,20 @@ class GatingGNO(nn.Module):
|
|||||||
plot_results=False,
|
plot_results=False,
|
||||||
batch=None,
|
batch=None,
|
||||||
pos=None,
|
pos=None,
|
||||||
|
boundary_mask=None,
|
||||||
):
|
):
|
||||||
x = self.encoder_x(x)
|
x = self.encoder_x(x)
|
||||||
c = self.encoder_c(c)
|
c = self.encoder_c(c)
|
||||||
if plot_results:
|
if plot_results:
|
||||||
x_ = self.dec(x)
|
x_ = self.dec(x)
|
||||||
plot_results_fn(x_, pos, 0, batch=batch)
|
plot_results_fn(x_, pos, 0, batch=batch)
|
||||||
|
bc = x[boundary_mask]
|
||||||
for _ in range(1, unrolling_steps + 1):
|
for _ in range(1, unrolling_steps + 1):
|
||||||
for i, blk in enumerate(self.blocks):
|
for i, blk in enumerate(self.blocks):
|
||||||
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
x = blk(x, c, edge_index, edge_attr=edge_attr)
|
||||||
if plot_results:
|
if plot_results:
|
||||||
x_ = self.dec(x)
|
x_ = self.dec(x)
|
||||||
|
assert bc == x[boundary_mask]
|
||||||
plot_results_fn(x_, pos, i * _, batch=batch)
|
plot_results_fn(x_, pos, i * _, batch=batch)
|
||||||
|
|
||||||
return self.dec(x)
|
return self.dec(x)
|
||||||
|
|||||||
@@ -28,20 +28,18 @@ class GraphSolver(LightningModule):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
c: torch.Tensor,
|
c: torch.Tensor,
|
||||||
boundary: torch.Tensor,
|
|
||||||
boundary_mask: torch.Tensor,
|
|
||||||
edge_index: torch.Tensor,
|
edge_index: torch.Tensor,
|
||||||
edge_attr: torch.Tensor,
|
edge_attr: torch.Tensor,
|
||||||
unrolling_steps: int = None,
|
unrolling_steps: int = None,
|
||||||
|
boundary_mask: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
return self.model(
|
return self.model(
|
||||||
x,
|
x,
|
||||||
c,
|
c,
|
||||||
boundary,
|
|
||||||
boundary_mask,
|
|
||||||
edge_index,
|
edge_index,
|
||||||
edge_attr,
|
edge_attr,
|
||||||
unrolling_steps,
|
unrolling_steps,
|
||||||
|
boundary_mask=boundary_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_loss(self, x, y):
|
def _compute_loss(self, x, y):
|
||||||
@@ -66,11 +64,10 @@ class GraphSolver(LightningModule):
|
|||||||
y_pred = self(
|
y_pred = self(
|
||||||
x,
|
x,
|
||||||
c,
|
c,
|
||||||
batch.boundary_values,
|
|
||||||
batch.boundary_mask,
|
|
||||||
edge_index=edge_index,
|
edge_index=edge_index,
|
||||||
edge_attr=edge_attr,
|
edge_attr=edge_attr,
|
||||||
unrolling_steps=self.unrolling_steps,
|
unrolling_steps=self.unrolling_steps,
|
||||||
|
boundary_mask=batch.boundary_mask,
|
||||||
)
|
)
|
||||||
loss = self.loss(y_pred, y)
|
loss = self.loss(y_pred, y)
|
||||||
boundary_loss = self.loss(
|
boundary_loss = self.loss(
|
||||||
@@ -85,8 +82,6 @@ class GraphSolver(LightningModule):
|
|||||||
y_pred = self(
|
y_pred = self(
|
||||||
x,
|
x,
|
||||||
c,
|
c,
|
||||||
batch.boundary_values,
|
|
||||||
batch.boundary_mask,
|
|
||||||
edge_index=edge_index,
|
edge_index=edge_index,
|
||||||
edge_attr=edge_attr,
|
edge_attr=edge_attr,
|
||||||
unrolling_steps=self.unrolling_steps,
|
unrolling_steps=self.unrolling_steps,
|
||||||
@@ -104,8 +99,6 @@ class GraphSolver(LightningModule):
|
|||||||
y_pred = self.model(
|
y_pred = self.model(
|
||||||
x,
|
x,
|
||||||
c,
|
c,
|
||||||
batch.boundary_values,
|
|
||||||
batch.boundary_mask,
|
|
||||||
edge_index=edge_index,
|
edge_index=edge_index,
|
||||||
edge_attr=edge_attr,
|
edge_attr=edge_attr,
|
||||||
unrolling_steps=self.unrolling_steps,
|
unrolling_steps=self.unrolling_steps,
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from torch_geometric.data import Data
|
|||||||
D_IN_KEYS = "x"
|
D_IN_KEYS = "x"
|
||||||
D_ATTR_KEYS = ["c", "edge_attr"]
|
D_ATTR_KEYS = ["c", "edge_attr"]
|
||||||
D_OUT_KEY = "y"
|
D_OUT_KEY = "y"
|
||||||
D_KEYS = [D_IN_KEYS] + [D_OUT_KEY] + D_ATTR_KEYS
|
D_KEYS = D_ATTR_KEYS + [D_OUT_KEY]
|
||||||
D_BOUNDS_KEYS = "boundary_temperatures"
|
|
||||||
|
|
||||||
|
|
||||||
class Normalizer:
|
class Normalizer:
|
||||||
@@ -28,24 +27,17 @@ class Normalizer:
|
|||||||
std[key] = tmp.std(dim=0, keepdim=True) + 1e-6
|
std[key] = tmp.std(dim=0, keepdim=True) + 1e-6
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
def normalize(self, data):
|
@staticmethod
|
||||||
|
def _apply_input_boundary(data: Data):
|
||||||
|
bc = data.y[data.boundary_mask]
|
||||||
|
data[D_IN_KEYS][data.boundary_mask] = bc
|
||||||
|
|
||||||
|
def normalize(self, data: list[Data]):
|
||||||
for d in data:
|
for d in data:
|
||||||
for key in D_KEYS:
|
for key in D_KEYS:
|
||||||
if not hasattr(d, key):
|
|
||||||
raise AttributeError(f"Manca '{key}' in uno dei Data.")
|
|
||||||
d[key] = (d[key] - self.mean[key]) / self.std[key]
|
d[key] = (d[key] - self.mean[key]) / self.std[key]
|
||||||
self._recompute_boundary_temperatures(data)
|
self._apply_input_boundary(d)
|
||||||
|
return data
|
||||||
def _recompute_boundary_temperatures(self, data):
|
|
||||||
for d in data:
|
|
||||||
bottom_bc = d.y[d.bottom_boundary_ids].median()
|
|
||||||
top_bc = d.y[d.top_boundary_ids].median()
|
|
||||||
left_bc = d.y[d.left_boundary_ids].median()
|
|
||||||
right_bc = d.y[d.right_boundary_ids].median()
|
|
||||||
boundaries_temperatures = torch.tensor(
|
|
||||||
[bottom_bc, right_bc, top_bc, left_bc], dtype=torch.float32
|
|
||||||
)
|
|
||||||
d.boundary_temperatures = boundaries_temperatures.unsqueeze(0)
|
|
||||||
|
|
||||||
def denormalize(self, y: torch.tensor):
|
def denormalize(self, y: torch.tensor):
|
||||||
return y * self.std[D_OUT_KEY] + self.mean[D_OUT_KEY]
|
return y * self.std[D_OUT_KEY] + self.mean[D_OUT_KEY]
|
||||||
|
|||||||
Reference in New Issue
Block a user