thermal-conduction-ml/ThermalSolver/switch_dataloader_callback.py

import os

import torch
from lightning.pytorch.callbacks import Callback


class SwitchDataLoaderCallback(Callback):
    """Tracks the validation metric, checkpoints the best model, and gradually
    increases the unrolling horizon of the module and datamodule, reloading the
    best weights before each increase."""

    def __init__(
        self,
        ckpt_path,
        increase_unrolling_steps_by,
        increase_unrolling_steps_every,
        max_unrolling_steps=10,
        patience=None,
        last_patience=None,
        metric="val/loss",
    ):
        super().__init__()
        self.ckpt_path = ckpt_path
        os.makedirs(ckpt_path, exist_ok=True)
        self.increase_unrolling_steps_by = increase_unrolling_steps_by
        self.increase_unrolling_steps_every = increase_unrolling_steps_every
        self.max_unrolling_steps = max_unrolling_steps
        self.metric = metric
        self.actual_loss = torch.inf  # best metric value at the current unrolling level
        # Set both patience attributes unconditionally so the `is not None`
        # checks below cannot raise an AttributeError when they are omitted.
        self.patience = patience
        self.last_patience = last_patience
        self.no_improvement_epochs = 0
        self.last_step_reached = False
    def on_validation_epoch_end(self, trainer, pl_module):
        self._metric_tracker(trainer, pl_module)
        if not self.last_step_reached:
            self._unrolling_steps_handler(pl_module, trainer)
        elif (
            self.last_patience is not None
            and self.no_improvement_epochs >= self.last_patience
        ):
            # At the maximum unrolling horizon and the metric has stalled: stop training.
            trainer.should_stop = True
    def _metric_tracker(self, trainer, pl_module):
        current = trainer.callback_metrics.get(self.metric)
        if current is None:
            # Metric was not logged this epoch (e.g. during the sanity check); skip.
            return
        if current < self.actual_loss:
            self.actual_loss = current
            self._save_model(pl_module, trainer)
            self.no_improvement_epochs = 0
            print(f"\nNew best {self.metric}: {self.actual_loss:.4f}")
        else:
            self.no_improvement_epochs += 1
            print(
                f"\nNo improvement in {self.metric} for {self.no_improvement_epochs} epochs."
            )
    def _should_reload_dataloader(self, trainer):
        if self.patience is not None:
            print(
                f"Checking patience: {self.no_improvement_epochs} / {self.patience}"
            )
            if self.no_improvement_epochs >= self.patience:
                return True
        elif (trainer.current_epoch + 1) % self.increase_unrolling_steps_every == 0:
            # No patience configured: increase on a fixed epoch schedule instead.
            print("Reached scheduled epoch for increasing unrolling steps.")
            return True
        return False
    def _unrolling_steps_handler(self, pl_module, trainer):
        if self._should_reload_dataloader(trainer):
            self._load_model(pl_module)
            if pl_module.unrolling_steps >= self.max_unrolling_steps:
                return
            pl_module.unrolling_steps += self.increase_unrolling_steps_by
            trainer.datamodule.unrolling_steps = pl_module.unrolling_steps
            print(f"Incremented unrolling steps to {pl_module.unrolling_steps}")
            # Rebuild the datasets for the new unrolling horizon and make the
            # trainer pick up the new dataloaders. `manual_dataloader_reload`
            # is a project-specific Trainer helper, not a built-in Lightning API.
            trainer.datamodule.setup(stage="fit")
            trainer.manual_dataloader_reload()
            self.actual_loss = torch.inf
            if pl_module.unrolling_steps >= self.max_unrolling_steps:
                print("Reached max unrolling steps. Stopping further increments.")
                self.last_step_reached = True
    def _save_model(self, pl_module, trainer):
        # Plain state_dict for fast weight reloading when the horizon grows.
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        torch.save(pl_module.state_dict(), pt_path)
        # Full Lightning checkpoint (optimizer state included) for resuming runs.
        ckpt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_checkpoint.ckpt",
        )
        trainer.save_checkpoint(ckpt_path, weights_only=False)
    def _load_model(self, pl_module):
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        pl_module.load_state_dict(torch.load(pt_path, weights_only=True))
        print(
            f"Loaded model weights from {pt_path} for unrolling steps = {pl_module.unrolling_steps}"
        )