Source code for mnist_sandbox.train

import subprocess
from pathlib import Path

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import MLFlowLogger
from torch.utils.data import DataLoader


def train_m(
    model: pl.LightningModule,
    n_epochs: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    save: bool = False,
    logging_url: str = "file:./.logs/my-mlflow-logs",
) -> pl.LightningModule:
    """Train ``model`` on ``train_loader`` and validate it on ``val_loader``.

    The run is logged through an ``MLFlowLogger`` whose experiment name is the
    short hash of the current git ``HEAD`` commit. Training runs on the CPU
    with a single device and with Lightning checkpointing disabled.

    When ``save`` is true, the trained weights (``model.state_dict()``) are
    written to ``models/sota_mnist_cnn.pth`` relative to the current working
    directory; the ``models`` directory is created if it is missing.

    Parameters
    ----------
    model: pl.LightningModule
        Model to be trained.
    n_epochs: int
        Count of training epochs.
    train_loader: DataLoader
        Training dataset.
    val_loader: DataLoader
        Validation dataset.
    save: bool
        Save result model. Default is False, because DVC is used.
    logging_url: str
        Where to store mlflow logs (passed as the logger's tracking URI).

    Returns
    -------
    model: pl.LightningModule
        Trained model (the same object that was passed in).

    Raises
    ------
    subprocess.CalledProcessError
        If ``git rev-parse`` exits with a non-zero status, e.g. when the
        working directory is not inside a git repository.
    FileNotFoundError
        If the ``git`` executable cannot be found.
    """
    # The short commit hash doubles as the MLflow experiment name, so every
    # run can be traced back to the code revision that produced it.
    git_commit_id = (
        subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
        .strip()
        .decode("utf-8")
    )

    logger = MLFlowLogger(experiment_name=git_commit_id, tracking_uri=logging_url)

    # dict.update() mutates in place and returns None, so the merged mapping
    # has to be built first and then handed to the logger; passing the result
    # of .update() directly would log nothing.
    hyperparams = dict(model.hparams)
    hyperparams["git_commit_id"] = git_commit_id
    logger.log_hyperparams(hyperparams)

    trainer = pl.Trainer(
        accelerator="cpu",
        devices=1,
        max_epochs=n_epochs,
        logger=logger,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_loader, val_loader)

    if save:
        models_dir = Path("models")
        # exist_ok=True already covers the "directory exists" case, so no
        # separate existence check is needed.
        models_dir.mkdir(parents=True, exist_ok=True)
        weights_path = models_dir / "sota_mnist_cnn.pth"
        torch.save(model.state_dict(), weights_path)
        # Report the location anchored at the current working directory.
        print(f"Model saved into {Path.cwd() / weights_path}")
    else:
        print("Model stored in DVC, so it's not saving.")

    return model