Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
pytorch-lightning基本信息
pytorch-lightning介绍
安装 Lightning
从 PyPI 简单安装
pip install lightning
其他安装选项
安装带可选依赖项的版本
pip install lightning['extra']
Conda 安装
conda install lightning -c conda-forge
安装稳定版本
从源码安装未来发布的稳定版本
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U
安装最新版本
从源码安装 nightly 版本(不保证稳定性)
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
或从测试 PyPI 安装
pip install -iU https://test.pypi.org/simple/ pytorch-lightning
Lightning 有两个核心包
PyTorch Lightning: 大规模训练和部署 PyTorch。
Lightning Fabric: 专家控制。
Lightning 让你可以精细控制在 PyTorch 上增加的抽象层次。
PyTorch Lightning: 大规模训练和部署 PyTorch
PyTorch Lightning 只是更有组织的 PyTorch - Lightning 解耦了 PyTorch 代码,将科学与工程分离。
示例
探索使用 PyTorch Lightning 进行各种类型的训练。预训练和微调任何类型的模型,以执行分类、分割、摘要等任务:
任务 | 描述 | 运行 |
---|---|---|
Hello world | 预训练 - Hello world 示例 | |
图像分割 | 微调 - 使用 ResNet-50 模型进行图像分割 | |
文本分类 | 微调 - 文本分类器(BERT 模型) | |
文本摘要 | 微调 - 文本摘要(Hugging Face transformer 模型) | |
音频生成 | 微调 - 音频生成(transformer 模型) |
Hello 简单模型
# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L
# --------------------------------
# 步骤1: 定义一个 LightningModule
# --------------------------------
# 一个 LightningModule(nn.Module 的子类)定义了一个完整的*系统*
# (例如:一个 LLM、扩散模型、自编码器,或一个简单的图像分类器)。
class LitAutoEncoder(L.LightningModule):
def __init__(
self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
def forward(self, x):
# 在 lightning 中,forward 定义了预测/推理行为
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
# training_step 定义了训练循环。它独立于 forward
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
# -------------------
# 步骤2: 定义数据
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])
# -------------------
# 步骤3: 训练
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))
在终端运行模型
pip install torchvision
python main.py
高级功能
Lightning 具有超过40+高级功能,专为大规模专业AI研究设计。
以下是一些示例:
在成千上万的GPU上训练而无需更改代码
# 8 个 GPU
# 无需代码更改
trainer = Trainer(accelerator="gpu", devices=8)
# 256 个 GPU
trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
在其他加速器(如 TPU)上训练而无需更改代码
# 无需代码更改
trainer = Trainer(accelerator="tpu", devices=8)
16位精度
# 无需代码更改
trainer = Trainer(precision=16)
实验管理器
from lightning import loggers
# tensorboard
trainer = Trainer(logger=TensorBoardLogger("logs/"))
# weights and biases
trainer = Trainer(logger=loggers.WandbLogger())
# comet
trainer = Trainer(logger=loggers.CometLogger())
# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())
# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())
# 还有很多
早停
es = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[es])
检查点
checkpointing = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpointing])
导出到 torchscript (JIT)(生产用途)
# torchscript
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
导出到 ONNX(生产用途)
# onnx
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
autoencoder = LitAutoEncoder()
input_sample = torch.randn((1, 64))
autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
os.path.isfile(tmpfile.name)
相比未结构化的 PyTorch 的优势
- 模型变得与硬件无关
- 代码清晰易读,因为工程代码被抽象掉了
- 更容易复现
- 减少错误,因为 Lightning 处理了复杂的工程问题
- 保留所有灵活性(LightningModules 仍然是 PyTorch 模块),但去除了大量的样板代码
- Lightning 与流行的机器学习工具有数十种集成。
- 每个新的 PR 都经过严格测试。我们测试每个支持的 PyTorch 和 Python 版本组合,每个操作系统,多 GPU 甚至 TPU。
- 运行速度开销最小(每个 epoch 仅比纯 PyTorch 慢约 300 毫秒)。
Lightning Fabric: 专家控制
在任何设备上、任何规模下运行,拥有对 PyTorch 训练循环和扩展策略的专家级控制。你甚至可以编写自己的 Trainer。
Fabric 专为最复杂的模型设计,如基础模型扩展、LLM、扩散模型、Transformer、强化学习、主动学习等。适用于任何规模。
更改内容 | 结果 Fabric 代码(复制我!) |
---|---|
|
|
主要特性
轻松切换从 CPU 运行到 GPU(Apple Silicon、CUDA、…)、TPU、多 GPU 甚至多节点训练
# 使用你的可用硬件
# 无需代码更改
fabric = Fabric()
# 在 GPU(CUDA 或 MPS)上运行
fabric = Fabric(accelerator="gpu")
# 8 个 GPU
fabric = Fabric(accelerator="gpu", devices=8)
# 256 个 GPU,多节点
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)
# 在 TPU 上运行
fabric = Fabric(accelerator="tpu")
开箱即用的最先进的分布式训练策略(DDP、FSDP、DeepSpeed)和混合精度
# 使用最先进的分布式训练技术
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")
fabric = Fabric(strategy="fsdp")
# 切换精度
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64")
所有设备逻辑样板代码都为你处理
# 不再需要这些!
- model.to(device)
- batch.to(device)
使用 Fabric 原语构建你自己的自定义 Trainer 进行训练检查点记录、日志记录等
python
import lightning as L
class MyCustomTrainer:
def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)
def fit(self, model, optimizer, dataloader, max_epochs):
self.fabric.launch()
model, optimizer = self.fabric.setup(model, optimizer)
dataloader = self.fabric.setup_dataloaders(dataloader)
model.train()
for epoch in range(max_epochs):
for batch in dataloader:
input, target = batch
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
self.fabric.backward(loss)
optimizer.step()
你可以在我们的示例中找到更详细的示例
示例
自监督学习
卷积架构
强化学习
GANs
经典机器学习
持续集成
Lightning 在多个 CPU、GPU 和 TPU 以及主要的 Python 和 PyTorch 版本上进行了严格测试。
*Codecov 超过 90%+,但构建延迟可能显示更少
当前构建状态
系统 / PyTorch 版本 | 1.13 | 2.0 | 2.1 |
---|---|---|---|
Linux py3.9 [GPUs] | |||
Linux py3.9 [TPUs] | |||
Linux (多个 Python 版本) | |||
OSX (多个 Python 版本) | |||
Windows (多个 Python 版本) |
社区
Lightning 社区由以下人员维护
- 10 多名核心贡献者,他们都是来自顶级 AI 实验室的专业工程师、研究科学家和博士生的混合体。
- 800 多名社区贡献者。
想帮助我们构建 Lightning 并减少数千名研究人员的样板代码吗?了解如何在此处进行首次贡献
Lightning 也是PyTorch 生态系统的一部分,这要求项目必须有坚实的测试、文档和支持。
请求帮助
如果你有任何问题,请: