# NOTE: file-hosting metadata artifact removed (was: "105 lines / 4.0 KiB / Python",
# duplicated twice by the extractor — not part of the source).
import torch
|
|
from lightning.pytorch.callbacks import Callback
|
|
import os
|
|
|
|
|
|
class SwitchDataLoaderCallback(Callback):
    """Curriculum-style callback that grows the model's unrolling horizon.

    After each validation epoch it tracks ``metric``; whenever the metric
    improves, the best weights (plain ``.pt`` state dict) and a full trainer
    checkpoint are saved per unrolling stage. When either ``patience`` epochs
    pass without improvement or a scheduled epoch is reached, the best weights
    for the current stage are reloaded, ``unrolling_steps`` is increased by
    ``increase_unrolling_steps_by``, and the datamodule is re-set-up so the
    dataloaders reflect the new horizon. Once ``max_unrolling_steps`` is
    reached, training stops after ``last_patience`` stagnant epochs.
    """

    def __init__(
        self,
        ckpt_path,
        increase_unrolling_steps_by,
        increase_unrolling_steps_every,
        max_unrolling_steps=10,
        patience=None,
        last_patience=None,
        metric="val/loss",
    ):
        """Configure the unrolling-steps curriculum.

        Args:
            ckpt_path: Directory for per-stage weight files and checkpoints
                (created if missing).
            increase_unrolling_steps_by: Increment applied to
                ``pl_module.unrolling_steps`` at each stage switch.
            increase_unrolling_steps_every: Epoch period for scheduled stage
                switches (used only when ``patience`` is ``None``).
            max_unrolling_steps: Ceiling on ``unrolling_steps``; no increments
                occur at or beyond this value.
            patience: If set, switch stages after this many epochs without
                improvement in ``metric`` instead of on the epoch schedule.
            last_patience: If set, stop training after this many stagnant
                epochs once the maximum unrolling stage has been reached.
            metric: Key into ``trainer.callback_metrics`` to monitor
                (lower is better).
        """
        super().__init__()
        self.ckpt_path = ckpt_path
        # exist_ok avoids the check-then-create race of the original
        # `if os.path.exists(...) is False` pattern.
        os.makedirs(ckpt_path, exist_ok=True)
        self.increase_unrolling_steps_by = increase_unrolling_steps_by
        self.increase_unrolling_steps_every = increase_unrolling_steps_every
        self.max_unrolling_steps = max_unrolling_steps
        self.metric = metric
        # Best metric value seen so far for the current stage (lower = better).
        self.actual_loss = torch.inf
        # BUG FIX: always assign these attributes. The original only set them
        # when the arguments were not None, so _should_reload_dataloader and
        # on_validation_epoch_end raised AttributeError whenever patience /
        # last_patience were left at their defaults.
        self.patience = patience
        self.last_patience = last_patience
        self.no_improvement_epochs = 0
        # Set once unrolling_steps has been bumped to >= max_unrolling_steps.
        self.last_step_reached = False

    def on_validation_epoch_end(self, trainer, pl_module):
        """Track the metric, then either advance the curriculum or early-stop."""
        self._metric_tracker(trainer, pl_module)
        if not self.last_step_reached:
            self._unrolling_steps_handler(pl_module, trainer)
        elif (
            self.last_patience is not None
            and self.no_improvement_epochs >= self.last_patience
        ):
            # Final stage exhausted its patience budget: request early stop.
            trainer.should_stop = True

    def _metric_tracker(self, trainer, pl_module):
        """Update the best-metric bookkeeping and save the model on improvement."""
        current = trainer.callback_metrics.get(self.metric)
        if current is None:
            # Metric not logged this epoch (e.g. sanity-check pass) — the
            # original would raise TypeError comparing None < inf; skip instead.
            return
        if current < self.actual_loss:
            self.actual_loss = current
            self._save_model(pl_module, trainer)
            self.no_improvement_epochs = 0
            print(f"\nNew best {self.metric}: {self.actual_loss:.4f}")
        else:
            self.no_improvement_epochs += 1
            print(
                f"\nNo improvement in {self.metric} for {self.no_improvement_epochs} epochs."
            )

    def _should_reload_dataloader(self, trainer):
        """Return True when it is time to switch to the next unrolling stage.

        Patience mode (``patience`` set) takes precedence over the fixed
        epoch schedule; the schedule is only consulted when ``patience`` is
        ``None``.
        """
        if self.patience is not None:
            print(
                f"Checking patience: {self.no_improvement_epochs} / {self.patience}"
            )
            if self.no_improvement_epochs >= self.patience:
                return True
        # BUG FIX: the original wrote `current_epoch + 1 % every == 0`, which
        # parses as `current_epoch + (1 % every) == 0` and almost never fires.
        elif (trainer.current_epoch + 1) % self.increase_unrolling_steps_every == 0:
            print("Reached scheduled epoch for increasing unrolling steps.")
            return True
        return False

    def _unrolling_steps_handler(self, pl_module, trainer):
        """Reload best weights, bump unrolling steps, and rebuild dataloaders."""
        if not self._should_reload_dataloader(trainer):
            return
        # Restart the next stage from the best weights of the current one.
        self._load_model(pl_module)
        if pl_module.unrolling_steps >= self.max_unrolling_steps:
            # Already at the ceiling — nothing to increment.
            return
        pl_module.unrolling_steps += self.increase_unrolling_steps_by
        trainer.datamodule.unrolling_steps = pl_module.unrolling_steps
        print(f"Incremented unrolling steps to {pl_module.unrolling_steps}")
        # Rebuild datasets/dataloaders for the new horizon.
        trainer.datamodule.setup(stage="fit")
        # NOTE(review): manual_dataloader_reload is not a stock Lightning
        # Trainer method — presumably a project extension; confirm it exists.
        trainer.manual_dataloader_reload()
        # Each stage tracks its own best metric from scratch.
        self.actual_loss = torch.inf
        if pl_module.unrolling_steps >= self.max_unrolling_steps:
            print(
                "Reached max unrolling steps. Stopping further increments."
            )
            self.last_step_reached = True

    def _save_model(self, pl_module, trainer):
        """Save best weights (.pt state dict) and a full trainer checkpoint."""
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        # Plain state dict so _load_model can restore weights without
        # needing the full Lightning checkpoint machinery.
        torch.save(pl_module.state_dict(), pt_path)
        ckpt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_checkpoint.ckpt",
        )
        trainer.save_checkpoint(ckpt_path, weights_only=False)

    def _load_model(self, pl_module):
        """Restore the best weights saved for the current unrolling stage."""
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        # weights_only=True: safe deserialization, tensors only.
        pl_module.load_state_dict(torch.load(pt_path, weights_only=True))
        print(
            f"Loaded model weights from {pt_path} for unrolling steps = {pl_module.unrolling_steps}"
        )
|