mnist_sandbox package
Submodules
mnist_sandbox.eval module
- mnist_sandbox.eval.model_evaluate(model, test_loader)
Evaluate the model on the given test dataset.
Gradients are not computed during evaluation.
- Parameters:
model (pl.LightningModule) – Trained model in inference mode.
test_loader (DataLoader) – DataLoader over the test dataset to evaluate on.
- Returns:
correct (int) – Number of correct predictions.
total (int) – Total number of evaluated items.
- Return type:
Tuple[int, int, ndarray]
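The two counts are typically tallied by taking the arg-max class of each sample's scores and comparing it with the label. A minimal pure-Python sketch, assuming each batch yields `(logits, labels)` as plain lists; `tally_predictions` and `argmax` are hypothetical helpers, not part of `mnist_sandbox`:

```python
from typing import Iterable, List, Sequence, Tuple


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (the first one on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def tally_predictions(
    batches: Iterable[Tuple[List[List[float]], List[int]]],
) -> Tuple[int, int]:
    """Count correct predictions and total items over all batches (hypothetical helper)."""
    correct = 0
    total = 0
    for logits, labels in batches:
        for scores, label in zip(logits, labels):
            correct += int(argmax(scores) == label)
        total += len(labels)
    return correct, total


batches = [
    ([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [1, 2]),  # 1 of 2 correct
    ([[0.0, 0.2, 1.5]], [2]),                        # 1 of 1 correct
]
correct, total = tally_predictions(batches)
print(correct, total, correct / total)
```

Accuracy is then `correct / total`; in the real function the scores would come from the model's forward pass rather than hand-written lists.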
mnist_sandbox.model module
- class mnist_sandbox.model.MNISTNet(*args, **kwargs)
Bases: LightningModule
Convolutional neural network (CNN) for MNIST digit classification.
- Model architecture
  Layer | Name  | Type   | Params
  ------+-------+--------+-------
  0     | conv1 | Conv2d | 320
  1     | conv2 | Conv2d | 18.5 K
  2     | conv3 | Conv2d | 55.4 K
  3     | fc1   | Linear | 442 K
  4     | fc2   | Linear | 65.7 K
  5     | fc3   | Linear | 1.3 K
  ------+-------+--------+-------
  584 K   Trainable params
  0       Non-trainable params
  584 K   Total params
  2.336   Total estimated model params size (MB)
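The counts follow from the standard formulas: a Conv2d layer has in·out·k² weights plus one bias per output channel, and a Linear layer has in·out weights plus one bias per output feature. For example, conv1's 320 parameters match 1·32·3² + 32, which is consistent with single-channel (grayscale) input. In the sketch below, the layer shapes (1→32→64→96 channels with 3×3 kernels, a 96·3·3 = 864-feature flatten, then 864→512→128→10) are inferred so that they reproduce the table; they are assumptions, not read from the source. The size estimate assumes 4-byte float32 parameters.

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Conv2d: in_ch * out_ch * kernel**2 weights plus out_ch biases."""
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    """Linear: in_features * out_features weights plus out_features biases."""
    return in_features * out_features + out_features


def size_mb(n_params: int, bytes_per_param: int = 4) -> float:
    """Estimated size in MB (10**6 bytes); 4 bytes per float32 parameter."""
    return n_params * bytes_per_param / 1e6


# Assumed layer shapes (inferred to match the table, not read from the source).
layers = {
    "conv1": conv2d_params(1, 32, 3),       # 320
    "conv2": conv2d_params(32, 64, 3),      # 18496  (table: 18.5 K)
    "conv3": conv2d_params(64, 96, 3),      # 55392  (table: 55.4 K)
    "fc1": linear_params(96 * 3 * 3, 512),  # 442880 (table: 442 K)
    "fc2": linear_params(512, 128),         # 65664  (table: 65.7 K)
    "fc3": linear_params(128, 10),          # 1290   (table: 1.3 K)
}
total = sum(layers.values())                # 584042 (table: 584 K)
print(total, round(size_mb(total), 3))      # 584042 2.336
```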
Methods
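__call__(*args, **kwargs)
configure_optimizers()              Configure the optimizer and learning-rate scheduler
forward(x)                          Forward pass of the model
on_train_epoch_end()                Hook called at the end of each training epoch
on_validation_epoch_end()           Hook called at the end of each validation epoch
training_step(batch, batch_idx)     Model training step
validation_step(batch, batch_idx)   Model validation step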
- configure_optimizers()
Configure the optimizer and learning-rate scheduler.
- Returns:
output – A tuple of an Adam optimizer and a ReduceLROnPlateau scheduler.
- Return type:
Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]
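ReduceLROnPlateau lowers the learning rate when a monitored metric stops improving. The pure-Python sketch below shows only the core rule: multiply the rate by `factor` once the metric has failed to improve for more than `patience` consecutive epochs. It is a simplified illustration, not PyTorch's implementation (it ignores the `threshold`, `cooldown` and `min_lr` options), and the `factor` and `patience` values are arbitrary assumptions, not this project's settings.

```python
from typing import List


class PlateauSketch:
    """Simplified 'min'-mode plateau rule: no threshold, cooldown or min_lr."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 2) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> float:
        """Feed one epoch's monitored metric (e.g. val loss); return the current lr."""
        if metric < self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=1e-3, factor=0.1, patience=2)
val_losses = [1.0, 0.9, 0.95, 0.95, 0.95]
lrs: List[float] = [sched.step(v) for v in val_losses]
print(lrs)  # the rate drops only at the 5th epoch, after 3 epochs without improvement
```

Because the scheduler's `step()` needs a metric value, Lightning expects the scheduler configuration to name the logged metric to watch via a `monitor` key.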
- forward(x)
Forward pass of the model.
- Parameters:
x (torch.Tensor) – input batch of shape [batch_size, 1, 28, 28] (single-channel 28×28 images)
- Returns:
x – Tensor of shape [batch_size, 10] with one score (logit) per digit class
- Return type:
torch.Tensor
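If the returned scores are raw logits, a softmax turns each row into class probabilities, and the predicted digit is the arg-max. A pure-Python sketch of that post-processing for a single row; the function name and the example scores are illustrative:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


row = [2.0, 1.0, 0.1, -1.0, 0.0, 0.5, -0.5, 3.0, 0.2, -2.0]  # 10 digit scores
probs = softmax(row)
pred = max(range(len(probs)), key=probs.__getitem__)
print(pred, round(sum(probs), 6))
```

Softmax preserves the ordering of the scores, so the arg-max of the logits and of the probabilities is the same digit; the softmax is only needed when actual probabilities are wanted.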
- on_train_epoch_end()
Hook called at the end of each training epoch.
Computes the average loss and accuracy over the epoch, then logs and prints them.
- on_validation_epoch_end()
Hook called at the end of each validation epoch.
Computes the average loss and accuracy over the epoch, then logs and prints them.
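These hooks receive no step outputs as arguments, so per-batch results are typically collected on the module during the steps and reduced at epoch end. A pure-Python sketch of that reduction, assuming each step stores a record with its loss, correct count and batch size; the function `epoch_summary`, the list `step_outputs` and the record keys are hypothetical:

```python
from typing import Dict, List, Tuple


def epoch_summary(step_outputs: List[Dict[str, float]]) -> Tuple[float, float]:
    """Average loss weighted by batch size, and accuracy over all samples."""
    n = sum(o["batch_size"] for o in step_outputs)
    avg_loss = sum(o["loss"] * o["batch_size"] for o in step_outputs) / n
    accuracy = sum(o["correct"] for o in step_outputs) / n
    return avg_loss, accuracy


step_outputs = [
    {"loss": 0.50, "correct": 60, "batch_size": 64},
    {"loss": 0.30, "correct": 30, "batch_size": 32},  # smaller final batch
]
avg_loss, accuracy = epoch_summary(step_outputs)
print(f"loss={avg_loss:.4f} acc={accuracy:.4f}")
step_outputs.clear()  # reset before the next epoch
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one; a plain mean of per-batch values is a common, simpler alternative.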
mnist_sandbox.train module
- mnist_sandbox.train.train_m(model, n_epochs, train_loader, val_loader, save=False, logging_url='file:./.logs/my-mlflow-logs')
Train the model on train_loader and validate it on val_loader.
If save is True, the trained model weights are saved to sota_mnist_cnn.pth.
- Parameters:
model (pl.LightningModule) – Model to be trained.
n_epochs (int) – Count of training epochs.
train_loader (DataLoader) – Training dataset.
val_loader (DataLoader) – Validation dataset.
save (bool) – Whether to save the resulting model weights. Defaults to False because DVC is used instead.
logging_url (str) – MLflow tracking URI where logs are stored.
- Returns:
model – Trained model.
- Return type:
pl.LightningModule
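The default logging_url is a `file:` URI, which MLflow treats as a local directory for its tracking data. A small standard-library sketch of how that URI splits into a scheme and a relative path, assuming train_m passes the value to MLflow unchanged as its tracking URI:

```python
from urllib.parse import urlparse

logging_url = "file:./.logs/my-mlflow-logs"  # default from train_m's signature
parsed = urlparse(logging_url)
print(parsed.scheme)  # file
print(parsed.path)    # ./.logs/my-mlflow-logs (relative to the working directory)
```

Because the path is relative, the logs land under the directory from which training is launched; an `http://` or `https://` address would instead point at a remote MLflow tracking server.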