Spend more time on research, less on engineering. PyTorch Lightning is fully flexible to fit any use case and is built on pure PyTorch, so there is no need to learn a new language. A quick refactor will allow you to:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl