# PyTorch Lightning
DVCLive allows you to add experiment tracking capabilities to your PyTorch Lightning projects.
If you are using Lightning Fabric, check the DVCLive - Lightning Fabric page.
## Usage
If you pass the `DVCLiveLogger` to your `Trainer`, DVCLive will automatically log the metrics and parameters tracked in your `LightningModule`.
```python
import lightning.pytorch as pl
from dvclive.lightning import DVCLiveLogger

...

class LitModule(pl.LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # layer_1_dim and learning_rate will be logged by DVCLive
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        metric = ...
        # See Output format below
        self.log("train_metric", metric, on_step=False, on_epoch=True)


dvclive_logger = DVCLiveLogger()

model = LitModule()
trainer = pl.Trainer(logger=dvclive_logger)
trainer.fit(model)
```

By default, PyTorch Lightning creates a directory to store checkpoints using the logger's name (`DVCLiveLogger`). You can change the checkpoint path or disable checkpointing entirely as described in the PyTorch Lightning documentation.
## Parameters
- `run_name` - (`None` by default) - Name of the run, used in PyTorch Lightning to get the version.

- `prefix` - (`None` by default) - String that is prepended to each metric name.

- `log_model` - (`False` by default) - Use `live.log_artifact()` to log checkpoints created by `ModelCheckpoint`. See Log model checkpoints.

  - If `log_model == False` (default), no checkpoint is logged.
  - If `log_model == True`, checkpoints are logged at the end of training, except when `save_top_k == -1`, which logs every checkpoint during training.
  - If `log_model == 'all'`, checkpoints are logged during training.

- `experiment` - (`None` by default) - `Live` object to be used instead of initializing a new one.

- `**kwargs` - Any additional arguments will be used to instantiate a new `Live` instance. If `experiment` is used, these arguments are ignored.
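The precedence between `experiment` and `**kwargs` described above can be sketched in plain Python. This is an illustration only, not DVCLive's actual code: `FakeLive` and `make_experiment` are hypothetical stand-ins for `Live` and the logger's internal setup.

```python
class FakeLive:
    """Stand-in for dvclive.Live (illustration only)."""
    def __init__(self, dir="dvclive"):
        self.dir = dir


def make_experiment(experiment=None, **kwargs):
    # Mirrors the documented precedence: an existing `experiment` wins,
    # and any extra kwargs are ignored; otherwise kwargs build a new instance.
    return experiment if experiment is not None else FakeLive(**kwargs)


existing = FakeLive("custom_dir")
print(make_experiment(experiment=existing, dir="ignored") is existing)  # True
print(make_experiment(dir="my_logs_dir").dir)  # my_logs_dir
```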
## Examples

### Log model checkpoints
Use `log_model` to save the checkpoints (it will use `Live.log_artifact()` internally to save them). At the end of training, DVCLive will copy the `best_model_path` to the `dvclive/artifacts` directory and annotate it with the name `best` (for example, to be consumed in the DVC Studio model registry or automation scenarios).
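As a rough sketch of the end-of-training behavior just described (illustrative only — `copy_best_checkpoint` is a hypothetical helper, not DVCLive's implementation), the best checkpoint is copied into a `dvclive/artifacts` directory:

```python
import shutil
import tempfile
from pathlib import Path


def copy_best_checkpoint(best_model_path: str, dvclive_dir: str) -> Path:
    # Mimic DVCLive copying best_model_path into {dvclive_dir}/artifacts (sketch)
    artifacts = Path(dvclive_dir) / "artifacts"
    artifacts.mkdir(parents=True, exist_ok=True)
    return Path(shutil.copy(best_model_path, artifacts))


# Demo with a throwaway checkpoint file
tmp = Path(tempfile.mkdtemp())
ckpt = tmp / "epoch=9-step=100.ckpt"
ckpt.write_bytes(b"fake checkpoint")
dest = copy_best_checkpoint(str(ckpt), str(tmp / "dvclive"))
print(dest)  # .../dvclive/artifacts/epoch=9-step=100.ckpt
```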
- Save updates to the checkpoints directory at the end of training:

```python
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model=True)
trainer = Trainer(logger=logger)
trainer.fit(model)
```

- Save updates to the checkpoints directory whenever a new checkpoint is saved:

```python
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model="all")
trainer = Trainer(logger=logger)
trainer.fit(model)
```

- Use a custom `ModelCheckpoint`:

```python
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model=True)
checkpoint_callback = ModelCheckpoint(
    dirpath="model",
    monitor="val_acc",
    mode="max",
)
trainer = Trainer(logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model)
```

### Passing additional DVCLive arguments
- Using `experiment` to pass an existing `Live` instance:

```python
from dvclive import Live
from dvclive.lightning import DVCLiveLogger

with Live("custom_dir") as live:
    trainer = Trainer(logger=DVCLiveLogger(experiment=live))
    trainer.fit(model)

    # Log additional metrics after training
    live.log_metric("summary_metric", 1.0, plot=False)
```

- Using `**kwargs` to customize `Live`:

```python
from dvclive.lightning import DVCLiveLogger

trainer = Trainer(logger=DVCLiveLogger(dir="my_logs_dir"))
trainer.fit(model)
```

## Output format
Each metric will be logged to:

```
{Live.plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv
```

Where:

- `{Live.plots_dir}` is defined in `Live`.
- `{iter_type}` can be either `epoch` or `step`. This is inferred from the `on_step` and `on_epoch` arguments used in the `log` call.
- `{split_prefix}_{metric_name}` is the full string passed to the `log` call. `split_prefix` can be either `train`, `val`, or `test`.

In the example above, the metric logged as:

```python
self.log("train_metric", metric, on_step=False, on_epoch=True)
```

will be stored in:

```
dvclive/metrics/train/epoch/metric.tsv
```
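The naming rule above can be sketched as a small pure-Python function. This is an illustration of the documented convention, not a DVCLive API (`metric_path` is hypothetical), and it simplifies the `on_step`/`on_epoch` inference to the two single-flag cases:

```python
def metric_path(plots_dir, logged_name, on_step=False, on_epoch=True):
    # "{split_prefix}_{metric_name}" is split on the first underscore
    split_prefix, metric_name = logged_name.split("_", 1)
    # Simplified inference: epoch-level unless only on_step is set
    iter_type = "epoch" if on_epoch and not on_step else "step"
    return f"{plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv"


print(metric_path("dvclive", "train_metric", on_step=False, on_epoch=True))
# dvclive/metrics/train/epoch/metric.tsv
```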