urbanlc.model.train_utils.load_checkpoint

urbanlc.model.train_utils.load_checkpoint(checkpoint_path: str, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None, device: device = device(type='cuda')) → Tuple[Module, Optimizer, _LRScheduler, int]

Load a model checkpoint, including the model state, optimizer state, scheduler state, and epoch.

This function reads the checkpoint file at checkpoint_path, loads it onto device, and returns the model, optimizer, and scheduler with their saved states, together with the saved epoch.
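The loading steps described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the urbanlc implementation: the helpers `save_checkpoint_sketch` and `load_checkpoint_sketch` and the dictionary keys (`"model_state_dict"`, `"optimizer_state_dict"`, `"scheduler_state_dict"`, `"epoch"`) are hypothetical, and the real checkpoint layout may differ. The sketch defaults to the CPU so it runs on any machine.

```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

# Hypothetical checkpoint keys; the real urbanlc file layout may differ.
MODEL_KEY = "model_state_dict"
OPTIM_KEY = "optimizer_state_dict"
SCHED_KEY = "scheduler_state_dict"
EPOCH_KEY = "epoch"


def save_checkpoint_sketch(path: str, model: nn.Module, optimizer: Optional[Optimizer],
                           scheduler: Optional[_LRScheduler], epoch: int) -> None:
    """Write all training state to a single file with torch.save."""
    torch.save(
        {
            MODEL_KEY: model.state_dict(),
            OPTIM_KEY: optimizer.state_dict() if optimizer is not None else None,
            SCHED_KEY: scheduler.state_dict() if scheduler is not None else None,
            EPOCH_KEY: epoch,
        },
        path,
    )


def load_checkpoint_sketch(
    path: str,
    model: nn.Module,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    device: torch.device = torch.device("cpu"),
) -> Tuple[nn.Module, Optional[Optimizer], Optional[_LRScheduler], int]:
    """Restore saved state into the given objects and return them with the epoch."""
    # map_location places every saved tensor on `device`, even if it was saved on a GPU.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint[MODEL_KEY])
    model.to(device)
    # The optimizer and scheduler are optional: only restore what was passed in and saved.
    if optimizer is not None and checkpoint[OPTIM_KEY] is not None:
        optimizer.load_state_dict(checkpoint[OPTIM_KEY])
    if scheduler is not None and checkpoint[SCHED_KEY] is not None:
        scheduler.load_state_dict(checkpoint[SCHED_KEY])
    return model, optimizer, scheduler, checkpoint[EPOCH_KEY]
```

Restoring state into objects the caller already constructed (rather than rebuilding them from the file) matches the documented parameters, which take an existing model, optimizer, and scheduler.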

Parameters:
  • checkpoint_path (str) – Path to the checkpoint file.

  • model (nn.Module) – PyTorch model into which the saved weights are loaded.

  • optimizer (Optional[torch.optim.Optimizer]) – Optimizer whose saved state is loaded. Defaults to None.

  • scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]) – Learning rate scheduler whose saved state is loaded. Defaults to None.

  • device (torch.device) – Device onto which the checkpoint is loaded. Defaults to the GPU if one is available.
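The signature renders the default as device(type='cuda'), which presumes a CUDA device exists. On a CPU-only machine it is likely safer to choose the device explicitly. A minimal sketch follows; that load_checkpoint forwards device to the `map_location` argument of `torch.load` is an assumption, since this page does not say how the device is applied.

```python
import io

import torch

# Choose the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustration only: how a device is commonly applied when reading a saved file.
# torch.load's map_location moves each stored tensor onto `device`.
buffer = io.BytesIO()
torch.save({"weight": torch.ones(2, 2)}, buffer)
buffer.seek(0)
state = torch.load(buffer, map_location=device)
print(state["weight"].device)
```

The chosen device can then be passed explicitly, for example load_checkpoint(checkpoint_path, model, device=device).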

Returns:

Loaded model, optimizer, scheduler, and epoch.

Return type:

Tuple[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, int]