import os

import torch
from lightning.pytorch.callbacks import Callback


class SwitchDataLoaderCallback(Callback):
    """Curriculum-style callback that progressively increases ``unrolling_steps``.

    The callback monitors a validation metric. For the current unrolling-step
    count it keeps the best weights on disk; when the metric plateaus
    (``patience``) or a scheduled epoch boundary is hit, it restores the best
    weights, increments ``pl_module.unrolling_steps`` (mirrored onto the
    datamodule), and forces a dataloader reload. Once ``max_unrolling_steps``
    is reached, training is stopped after ``last_patience`` epochs without
    improvement.

    NOTE(review): relies on a project-specific ``trainer.manual_dataloader_reload()``
    and on ``pl_module`` / ``trainer.datamodule`` exposing ``unrolling_steps`` —
    confirm against the owning LightningModule/DataModule.
    """

    def __init__(
        self,
        ckpt_path,
        increase_unrolling_steps_by,
        increase_unrolling_steps_every,
        max_unrolling_steps=10,
        patience=None,
        last_patience=None,
        metric="val/loss",
    ):
        """
        Args:
            ckpt_path: Directory for best-model files (created if missing).
            increase_unrolling_steps_by: Increment applied at each switch.
            increase_unrolling_steps_every: Epoch interval for scheduled
                increments; used only when ``patience`` is None.
            max_unrolling_steps: Upper bound on ``unrolling_steps``.
            patience: Epochs without improvement before switching. If None,
                switching follows the fixed epoch schedule instead.
            last_patience: Epochs without improvement, after the maximum
                unrolling steps are reached, before training stops. If None,
                training is never stopped by this callback.
            metric: Key looked up in ``trainer.callback_metrics``; lower is
                treated as better.
        """
        super().__init__()
        self.ckpt_path = ckpt_path
        # exist_ok avoids the check-then-create race of the original
        # `if os.path.exists(...) is False` pattern.
        os.makedirs(ckpt_path, exist_ok=True)
        self.increase_unrolling_steps_by = increase_unrolling_steps_by
        self.increase_unrolling_steps_every = increase_unrolling_steps_every
        self.max_unrolling_steps = max_unrolling_steps
        self.metric = metric
        # Best metric value seen for the current unrolling-step count.
        self.actual_loss = torch.inf
        # BUG FIX: always bind these attributes. The original assigned them
        # only when the argument was not None, so leaving either at its
        # default raised AttributeError on first read.
        self.patience = patience
        self.last_patience = last_patience
        self.no_improvement_epochs = 0
        self.last_step_reached = False

    def on_validation_epoch_end(self, trainer, pl_module):
        """Track the metric, then either advance the curriculum or stop."""
        self._metric_tracker(trainer, pl_module)
        if self.last_step_reached is False:
            self._unrolling_steps_handler(pl_module, trainer)
        else:
            # Max unrolling steps reached: early-stop once last_patience
            # epochs pass without improvement. Guard against last_patience
            # being None (the original compared against None and raised).
            if (
                self.last_patience is not None
                and self.no_improvement_epochs >= self.last_patience
            ):
                trainer.should_stop = True

    def _metric_tracker(self, trainer, pl_module):
        """Update best-metric bookkeeping and checkpoint on improvement."""
        current = trainer.callback_metrics.get(self.metric)
        if current is None:
            # Metric was not logged this epoch; nothing to compare against
            # (the original crashed with a None comparison here).
            return
        if current < self.actual_loss:
            self.actual_loss = current
            self._save_model(pl_module, trainer)
            self.no_improvement_epochs = 0
            print(f"\nNew best {self.metric}: {self.actual_loss:.4f}")
        else:
            self.no_improvement_epochs += 1
            print(
                f"\nNo improvement in {self.metric} for {self.no_improvement_epochs} epochs."
            )

    def _should_reload_dataloader(self, trainer):
        """Return True when it is time to increase the unrolling steps.

        Patience mode takes precedence: when ``patience`` is set, only a
        plateau triggers a switch; otherwise a fixed epoch schedule is used.
        """
        if self.patience is not None:
            print(
                f"Checking patience: {self.no_improvement_epochs} / {self.patience}"
            )
            if self.no_improvement_epochs >= self.patience:
                return True
        # BUG FIX: the original wrote
        # `trainer.current_epoch + 1 % self.increase_unrolling_steps_every == 0`,
        # which parses as `current_epoch + (1 % every) == 0` and is almost
        # never true. The intended check is on the 1-based epoch number.
        elif (trainer.current_epoch + 1) % self.increase_unrolling_steps_every == 0:
            print("Reached scheduled epoch for increasing unrolling steps.")
            return True
        return False

    def _unrolling_steps_handler(self, pl_module, trainer):
        """Restore best weights, bump unrolling steps, reload dataloaders."""
        if self._should_reload_dataloader(trainer):
            # Start the harder curriculum stage from the best weights of the
            # current stage, not from wherever training drifted to.
            self._load_model(pl_module)
            if pl_module.unrolling_steps >= self.max_unrolling_steps:
                return
            pl_module.unrolling_steps += self.increase_unrolling_steps_by
            trainer.datamodule.unrolling_steps = pl_module.unrolling_steps
            print(f"Incremented unrolling steps to {pl_module.unrolling_steps}")
            # Re-run setup so the datamodule rebuilds datasets for the new
            # unrolling length, then force Lightning to pick them up.
            trainer.datamodule.setup(stage="fit")
            trainer.manual_dataloader_reload()
            # Reset the best metric: losses are not comparable across
            # different unrolling lengths.
            self.actual_loss = torch.inf
            if pl_module.unrolling_steps >= self.max_unrolling_steps:
                print(
                    "Reached max unrolling steps. Stopping further increments."
                )
                self.last_step_reached = True

    def _save_model(self, pl_module, trainer):
        """Persist best weights (.pt) and a full trainer checkpoint (.ckpt)."""
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        # Plain state_dict so _load_model can restore with weights_only=True.
        torch.save(pl_module.state_dict(), pt_path)
        ckpt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_checkpoint.ckpt",
        )
        trainer.save_checkpoint(ckpt_path, weights_only=False)

    def _load_model(self, pl_module):
        """Restore the best weights saved for the current unrolling-step count."""
        pt_path = os.path.join(
            self.ckpt_path,
            f"{pl_module.unrolling_steps}_unrolling_best_model.pt",
        )
        # weights_only=True avoids unpickling arbitrary objects from disk.
        pl_module.load_state_dict(torch.load(pt_path, weights_only=True))
        print(
            f"Loaded model weights from {pt_path} for unrolling steps = {pl_module.unrolling_steps}"
        )