Updates to tutorial and run post codacy changes
This commit is contained in:
committed by
Nicola Demo
parent
9e55746546
commit
b38b0894b1
124
tutorials/tutorial4/tutorial.ipynb
vendored
124
tutorials/tutorial4/tutorial.ipynb
vendored
@@ -45,13 +45,16 @@
|
||||
"import torch \n",
|
||||
"import matplotlib.pyplot as plt \n",
|
||||
"import torchvision # for MNIST dataset\n",
|
||||
"import warnings\n",
|
||||
"\n",
|
||||
"from pina.problem import AbstractProblem\n",
|
||||
"from pina.solver import SupervisedSolver\n",
|
||||
"from pina.trainer import Trainer\n",
|
||||
"from pina import Condition, LabelTensor\n",
|
||||
"from pina.model.block import ContinuousConvBlock \n",
|
||||
"from pina.model import FeedForward # for building AE and MNIST classification"
|
||||
"from pina.model import FeedForward # for building AE and MNIST classification\n",
|
||||
"\n",
|
||||
"warnings.filterwarnings('ignore')"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -210,15 +213,7 @@
|
||||
"execution_count": 3,
|
||||
"id": "b78c08b8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/matte_b/.local/lib/python3.12/site-packages/torch/functional.py: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# filter dim\n",
|
||||
"filter_dim = [0.1, 0.1]\n",
|
||||
@@ -352,7 +347,105 @@
|
||||
"execution_count": 7,
|
||||
"id": "6d816e7a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
|
||||
"Failed to download (trying next):\n",
|
||||
"HTTP Error 404: Not Found\n",
|
||||
"\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 9.91M/9.91M [00:02<00:00, 3.62MB/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
|
||||
"\n",
|
||||
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
|
||||
"Failed to download (trying next):\n",
|
||||
"HTTP Error 404: Not Found\n",
|
||||
"\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 28.9k/28.9k [00:00<00:00, 210kB/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
|
||||
"\n",
|
||||
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
|
||||
"Failed to download (trying next):\n",
|
||||
"HTTP Error 404: Not Found\n",
|
||||
"\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1.65M/1.65M [00:00<00:00, 1.97MB/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
|
||||
"\n",
|
||||
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
|
||||
"Failed to download (trying next):\n",
|
||||
"HTTP Error 404: Not Found\n",
|
||||
"\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
|
||||
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 4.54k/4.54k [00:00<00:00, 2.55MB/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader, SubsetRandomSampler\n",
|
||||
"\n",
|
||||
@@ -793,7 +886,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 149: 100%|██████████| 1/1 [00:01<00:00, 0.77it/s, v_num=21, data_loss=0.0341, val_loss=0.0341, train_loss=0.0341]"
|
||||
"Epoch 149: 100%|██████████| 1/1 [00:01<00:00, 0.59it/s, v_num=25, data_loss=0.0341, train_loss=0.0341]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -807,7 +900,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 149: 100%|██████████| 1/1 [00:01<00:00, 0.76it/s, v_num=21, data_loss=0.0341, val_loss=0.0341, train_loss=0.0341]\n"
|
||||
"Epoch 149: 100%|██████████| 1/1 [00:01<00:00, 0.58it/s, v_num=25, data_loss=0.0341, train_loss=0.0341]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -823,7 +916,10 @@
|
||||
"solver = SupervisedSolver(problem=CircleProblem(), model=net, loss=torch.nn.MSELoss(), use_lt=True) \n",
|
||||
"\n",
|
||||
"# train\n",
|
||||
"trainer = Trainer(solver, max_epochs=150, accelerator='cpu', enable_model_summary=False) # we train on CPU and avoid model summary at beginning of training (optional)\n",
|
||||
"trainer = Trainer(solver, max_epochs=150, accelerator='cpu', enable_model_summary=False, # we train on CPU and avoid model summary at beginning of training (optional)\n",
|
||||
" train_size=1.0,\n",
|
||||
" val_size=0.0,\n",
|
||||
" test_size=0.0)\n",
|
||||
"trainer.train()\n",
|
||||
" "
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user