Use Weights & Biases for machine learning experiment tracking, model checkpointing, and collaboration with your team. See the full Weights & Biases documentation at https://docs.wandb.ai.

With just a few lines of code, you get rich, interactive, shareable dashboards.

We take security very seriously, and our cloud-hosted dashboard uses industry-standard best practices for encryption. If you're working with models and datasets that cannot leave your enterprise cluster, we have on-prem installations available.

It's also easy to download all your data and export it to other tools, like custom analysis in a Jupyter notebook, using our public API.
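For instance, here is a minimal sketch of pulling a run's metrics and config back out with the wandb public API. The "my-entity/basic-intro/run-id" path is a placeholder; substitute your own entity, project, and run ID.

import wandb

# Query the public API for a finished run (the path below is a placeholder)
api = wandb.Api()
run = api.run("my-entity/basic-intro/run-id")

history = run.history()    # logged metrics as a pandas DataFrame
config = dict(run.config)  # the hyperparameters you tracked
print(history.head())
print(config)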
Install the wandb library and log in

Start by installing the library and logging in to your free account.
!pip install wandb -qU
# Log in to your W&B account
import wandb
wandb.login()
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
True
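If you are running in a script or CI job where an interactive prompt is inconvenient, you can instead supply your API key through the WANDB_API_KEY environment variable before calling wandb.login(). The key below is a placeholder.

import os
import wandb

# Placeholder key: prefer exporting WANDB_API_KEY in your shell or CI secrets
os.environ["WANDB_API_KEY"] = "<your-api-key>"
wandb.login()  # picks up the key without prompting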
1. Start a new run and pass in hyperparameters to track
2. Log metrics from training or evaluation
3. Visualize results in the dashboard
import random

# Launch 5 simulated experiments
total_runs = 5
for run in range(total_runs):
    # 1. Start a new run to track this script
    wandb.init(
        # Set the project where this run will be logged
        project="basic-intro",
        # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
        name=f"experiment_{run}",
        # Track hyperparameters and run metadata
        config={
            "learning_rate": 0.02,
            "architecture": "CNN",
            "dataset": "CIFAR-100",
            "epochs": 10,
        })

    # This simple block simulates a training loop logging metrics
    epochs = 10
    offset = random.random() / 5
    for epoch in range(2, epochs):
        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
        loss = 2 ** -epoch + random.random() / epoch + offset

        # 2. Log metrics from your script to W&B
        wandb.log({"acc": acc, "loss": loss})

    # Mark the run as finished
    wandb.finish()
wandb: Currently logged in as: watanabe3tipapa. Use wandb login --relogin to force relogin

Each run prints its local run directory, sparkline histories for acc and loss, and a final summary. The five simulated runs here finished with acc/loss of 0.77788/0.1777, 0.8886/0.09287, 0.77004/0.26834, 0.93559/0.10872, and 0.80085/0.16985.
3. You can find your interactive dashboard by clicking any of the wandb links above.

Run this model to train a simple MNIST classifier, and click on the project page link to see your results stream in live to a W&B project.
Any run in wandb automatically logs metrics, system information, hyperparameters, and terminal output, and you'll see an interactive table with model inputs and outputs.
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T

device = "cuda:0" if torch.cuda.is_available() else "cpu"

def get_dataloader(is_train, batch_size, slice=5):
    "Get a training dataloader"
    full_dataset = torchvision.datasets.MNIST(root=".", train=is_train, transform=T.ToTensor(), download=True)
    sub_dataset = torch.utils.data.Subset(full_dataset, indices=range(0, len(full_dataset), slice))
    loader = torch.utils.data.DataLoader(dataset=sub_dataset,
                                         batch_size=batch_size,
                                         shuffle=True if is_train else False,
                                         pin_memory=True, num_workers=2)
    return loader

def get_model(dropout):
    "A simple model"
    model = nn.Sequential(nn.Flatten(),
                          nn.Linear(28*28, 256),
                          nn.BatchNorm1d(256),
                          nn.ReLU(),
                          nn.Dropout(dropout),
                          nn.Linear(256, 10)).to(device)
    return model

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.
    with torch.inference_mode():
        correct = 0
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            val_loss += loss_func(outputs, labels) * labels.size(0)

            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always the same batch_idx.
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)

def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"] + [f"score_{i}" for i in range(10)])
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table": table}, commit=False)

# Launch 5 experiments, trying different dropout rates
for _ in range(5):
    # Initialise a wandb run
    wandb.init(
        project="pytorch-intro",
        config={
            "epochs": 10,
            "batch_size": 128,
            "lr": 1e-3,
            "dropout": random.uniform(0.01, 0.80),
        })

    # Copy your config
    config = wandb.config

    # Get the data
    train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
    valid_dl = get_dataloader(is_train=False, batch_size=2*config.batch_size)
    n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

    # A simple MLP model
    model = get_model(config.dropout)

    # Make the loss and optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # Training
    example_ct = 0
    step_ct = 0
    for epoch in range(config.epochs):
        model.train()
        for step, (images, labels) in enumerate(train_dl):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            train_loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            example_ct += len(images)
            metrics = {"train/train_loss": train_loss,
                       "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                       "train/example_ct": example_ct}

            if step + 1 < n_steps_per_epoch:
                # Log train metrics to wandb
                wandb.log(metrics)

            step_ct += 1

        val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

        # Log train and validation metrics to wandb
        val_metrics = {"val/val_loss": val_loss,
                       "val/val_accuracy": accuracy}
        wandb.log({**metrics, **val_metrics})

        print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

    # If you had a test set, this is how you could log it as a Summary metric
    wandb.summary['test_accuracy'] = 0.8

    # Close your wandb run
    wandb.finish()
On the first run, torchvision downloads MNIST; after that, each of the five runs prints per-epoch metrics (for example, Train Loss: 0.220, Valid Loss: 0.284, Accuracy: 0.92, improving to roughly 0.95 by the last epoch) and ends with a run summary: train/epoch 10.0, train/example_ct 120000, test_accuracy 0.8, a final val/val_accuracy between about 0.94 and 0.95, and the path to its local logs.
You have now trained your first model using wandb! Click on the wandb link above to see your metrics.
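The intro mentioned model checkpointing. As a minimal sketch, you could save the trained model from the loop above as a versioned W&B Artifact; the file and artifact names here are illustrative assumptions, not part of the tutorial code.

# Save the model weights locally, then log them as a W&B Artifact
run = wandb.init(project="pytorch-intro", job_type="checkpoint")
torch.save(model.state_dict(), "model.pth")
artifact = wandb.Artifact("mnist-classifier", type="model")
artifact.add_file("model.pth")
run.log_artifact(artifact)
run.finish()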
W&B Alerts allows you to send alerts, triggered from your Python code, to your Slack or email. There are two steps to follow the first time you'd like to send a Slack or email alert triggered from your code:

1) Turn on Alerts in your W&B User Settings

2) Add wandb.alert() to your code:

wandb.alert(
    title="Low accuracy",
    text="Accuracy is below the acceptable threshold"
)

See the minimal example below for how to use wandb.alert. You can find the full docs for W&B Alerts in the W&B documentation.
# Start a wandb run
wandb.init(project="pytorch-intro")

# Simulating a model training loop
acc_threshold = 0.3
for training_step in range(1000):

    # Generate a random number for accuracy
    accuracy = round(random.random() + random.random(), 3)
    print(f'Accuracy is: {accuracy}, {acc_threshold}')

    # Log accuracy to wandb
    wandb.log({"Accuracy": accuracy})

    # If the accuracy is below the threshold, fire a W&B Alert and stop the run
    if accuracy <= acc_threshold:
        # Send the wandb Alert
        wandb.alert(
            title='Low Accuracy',
            text=f'Accuracy {accuracy} at step {training_step} is below the acceptable threshold, {acc_threshold}',
        )
        print('Alert triggered')
        break

# Mark the run as finished (useful in Jupyter notebooks)
wandb.finish()
The run prints each simulated accuracy against the threshold (for example, Accuracy is: 0.45, 0.3) until one falls below it (Accuracy is: 0.084, 0.3), prints Alert triggered, and stops; the run summary then shows the final logged Accuracy of 0.084 and the path to the local logs.
In the next tutorial, you will learn how to do hyperparameter optimization using W&B Sweeps.
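As a preview, a minimal sweep sketch might look like this; the configuration values and the toy train function are illustrative assumptions, not part of this tutorial.

import wandb

# Hypothetical sweep: random search over learning rate and dropout
sweep_config = {
    "method": "random",
    "metric": {"name": "val/val_loss", "goal": "minimize"},
    "parameters": {
        "lr": {"min": 1e-4, "max": 1e-2},
        "dropout": {"min": 0.01, "max": 0.80},
    },
}

def train():
    # wandb.agent calls this once per trial; wandb.init picks up the sweep's parameters
    with wandb.init() as run:
        # Stand-in for a real training loop that would log a genuine validation loss
        wandb.log({"val/val_loss": run.config.lr + run.config.dropout})

sweep_id = wandb.sweep(sweep_config, project="pytorch-intro")
wandb.agent(sweep_id, function=train, count=5)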