urbanlc.model.train_utils.load_checkpoint
- urbanlc.model.train_utils.load_checkpoint(checkpoint_path: str, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None, device: device = device(type='cuda')) → Tuple[Module, Optimizer, _LRScheduler, int]
Load a model checkpoint, including the model state, optimizer state, scheduler state, and epoch.
This function reads a checkpoint from a file, restores the saved states into the given model and, if provided, the optimizer and scheduler, and returns them together with the saved epoch.
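The page does not document the checkpoint's internal layout, so the following is only a minimal sketch of how a loader with this signature could work. It assumes the checkpoint is a dictionary saved with torch.save under the hypothetical keys "model_state_dict", "optimizer_state_dict", "scheduler_state_dict", and "epoch"; the helper names save_checkpoint_sketch and load_checkpoint_sketch are also hypothetical and are not part of urbanlc.

```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


def save_checkpoint_sketch(
    path: str,
    model: nn.Module,
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    epoch: int,
) -> None:
    # Hypothetical layout: one dict holding every state_dict plus the epoch.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch,
        },
        path,
    )


def load_checkpoint_sketch(
    checkpoint_path: str,
    model: nn.Module,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    device: torch.device = torch.device("cpu"),
) -> Tuple[nn.Module, Optional[Optimizer], Optional[_LRScheduler], int]:
    # map_location remaps every saved tensor onto the requested device.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Move the model first so the optimizer state is cast to the right device.
    model.to(device)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Optimizer and scheduler states are restored only when both the object
    # and its (assumed) key are present.
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = int(checkpoint.get("epoch", 0))
    return model, optimizer, scheduler, epoch
```

Whether the saved epoch is the last completed epoch or the next one to run depends on how the checkpoint was written, so check that convention before resuming a training loop from the returned value.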
- Parameters:
checkpoint_path (str) – Path to the checkpoint file.
model (nn.Module) – PyTorch model into which the saved state is loaded.
optimizer (Optional[torch.optim.Optimizer]) – Optimizer whose saved state is loaded, if provided. Defaults to None.
scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]) – Learning rate scheduler whose saved state is loaded, if provided. Defaults to None.
device (torch.device) – Device onto which the checkpoint is loaded. Defaults to the GPU if available (rendered in the signature as device(type='cuda')).
- Returns:
A tuple of the loaded model, optimizer, scheduler, and epoch.
- Return type:
Tuple[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, int]
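Because the signature renders the default device as device(type='cuda'), it may be safer on a CPU-only machine to pass device explicitly. A minimal sketch, assuming only the standard PyTorch API; pick_device is a hypothetical helper, and the commented call uses a hypothetical checkpoint path.

```python
import torch


def pick_device() -> torch.device:
    # Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()

# Hypothetical call, assuming model, optimizer, and scheduler were already
# constructed exactly as they were during training:
# model, optimizer, scheduler, epoch = load_checkpoint(
#     "checkpoints/last.pt", model, optimizer, scheduler, device=device
# )
```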