fix module and model + add curriculum callback
This commit is contained in:
104
ThermalSolver/switch_dataloader_callback.py
Normal file
104
ThermalSolver/switch_dataloader_callback.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
import os
|
||||
|
||||
|
||||
class SwitchDataLoaderCallback(Callback):
    """Curriculum-learning callback that gradually increases unrolling steps.

    Monitors a validation metric; on every improvement the module weights and
    a full trainer checkpoint are written to ``ckpt_path``. When the metric
    plateaus (patience mode) or a fixed epoch schedule is reached, the best
    weights for the current stage are restored, ``unrolling_steps`` is
    incremented on both the LightningModule and the datamodule, and the
    dataloaders are rebuilt. After ``max_unrolling_steps`` is reached,
    training stops once ``last_patience`` epochs pass without improvement.
    """

    def __init__(
        self,
        ckpt_path,
        increase_unrolling_steps_by,
        increase_unrolling_steps_every,
        max_unrolling_steps=10,
        patience=None,
        last_patience=None,
        metric="val/loss",
    ):
        """
        Args:
            ckpt_path: Directory for saved weights/checkpoints (created if missing).
            increase_unrolling_steps_by: Increment added at each curriculum switch.
            increase_unrolling_steps_every: Epoch period for scheduled switches
                (only consulted when ``patience`` is None).
            max_unrolling_steps: Upper bound; no increments beyond this value.
            patience: If set, switch after this many epochs without improvement
                instead of on the fixed epoch schedule.
            last_patience: If set, stop training after this many epochs without
                improvement once ``max_unrolling_steps`` has been reached.
            metric: Key in ``trainer.callback_metrics`` to monitor (lower is better).
        """
        super().__init__()
        self.ckpt_path = ckpt_path
        # exist_ok avoids the check-then-create race of the original
        # `if os.path.exists(...) is False` guard.
        os.makedirs(ckpt_path, exist_ok=True)
        self.increase_unrolling_steps_by = increase_unrolling_steps_by
        self.increase_unrolling_steps_every = increase_unrolling_steps_every
        self.max_unrolling_steps = max_unrolling_steps
        self.metric = metric
        self.actual_loss = torch.inf  # best metric value seen so far
        # BUGFIX: always define these attributes. The original assigned them
        # only when the arguments were not None, so _should_reload_dataloader
        # and the final-stage stopping check raised AttributeError whenever
        # the defaults were used.
        self.patience = patience
        self.last_patience = last_patience
        self.no_improvement_epochs = 0
        self.last_step_reached = False

    def on_validation_epoch_end(self, trainer, pl_module):
        """Update the plateau tracker, then advance the curriculum or stop."""
        self._metric_tracker(trainer, pl_module)
        if not self.last_step_reached:
            self._unrolling_steps_handler(pl_module, trainer)
        elif (
            self.last_patience is not None
            and self.no_improvement_epochs >= self.last_patience
        ):
            # Final stage exhausted its patience budget: request early stop.
            trainer.should_stop = True

    def _metric_tracker(self, trainer, pl_module):
        """Save the model on improvement; otherwise count plateau epochs."""
        current = trainer.callback_metrics.get(self.metric)
        if current is None:
            # Metric not logged this epoch (e.g. sanity-check validation):
            # skip tracking instead of crashing on a None comparison.
            return
        if current < self.actual_loss:
            self.actual_loss = current
            self._save_model(pl_module, trainer)
            self.no_improvement_epochs = 0
            print(f"\nNew best {self.metric}: {self.actual_loss:.4f}")
        else:
            self.no_improvement_epochs += 1
            print(
                f"\nNo improvement in {self.metric} for {self.no_improvement_epochs} epochs."
            )

    def _should_reload_dataloader(self, trainer):
        """Return True when the curriculum should advance (patience or schedule)."""
        if self.patience is not None:
            print(
                f"Checking patience: {self.no_improvement_epochs} / {self.patience}"
            )
            if self.no_improvement_epochs >= self.patience:
                return True
        # BUGFIX: parenthesize the addition. `%` binds tighter than `+`, so
        # the original computed `current_epoch + (1 % every)` and the
        # scheduled switch essentially never fired.
        elif (trainer.current_epoch + 1) % self.increase_unrolling_steps_every == 0:
            print("Reached scheduled epoch for increasing unrolling steps.")
            return True
        return False

    def _unrolling_steps_handler(self, pl_module, trainer):
        """Restore best weights, bump unrolling steps, and rebuild dataloaders."""
        if not self._should_reload_dataloader(trainer):
            return
        self._load_model(pl_module)
        if pl_module.unrolling_steps >= self.max_unrolling_steps:
            return
        pl_module.unrolling_steps += self.increase_unrolling_steps_by
        # Keep the datamodule in lockstep so the rebuilt loaders use the new value.
        trainer.datamodule.unrolling_steps = pl_module.unrolling_steps
        print(f"Incremented unrolling steps to {pl_module.unrolling_steps}")
        trainer.datamodule.setup(stage="fit")
        # NOTE(review): manual_dataloader_reload is not a standard Lightning
        # Trainer API — presumably a project-specific extension; confirm.
        trainer.manual_dataloader_reload()
        # Reset the best-loss tracker for the new, harder curriculum stage.
        self.actual_loss = torch.inf
        if pl_module.unrolling_steps >= self.max_unrolling_steps:
            print(
                "Reached max unrolling steps. Stopping further increments."
            )
            self.last_step_reached = True

    def _save_model(self, pl_module, trainer):
        """Save a raw state_dict plus a full trainer checkpoint for this stage."""
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        torch.save(pl_module.state_dict(), pt_path)
        ckpt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_checkpoint.ckpt",
        )
        trainer.save_checkpoint(ckpt_path, weights_only=False)

    def _load_model(self, pl_module):
        """Restore the best saved weights for the current unrolling stage."""
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        # weights_only=True: only tensors are deserialized (safe torch.load).
        pl_module.load_state_dict(torch.load(pt_path, weights_only=True))
        print(
            f"Loaded model weights from {pt_path} for unrolling steps = {pl_module.unrolling_steps}"
        )
|
||||
Reference in New Issue
Block a user