pytorch-lightning 介绍
了解项目的详细信息和使用方法
安装 Lightning
从 PyPI 简单安装
bash1pip install lightning
其他安装选项
安装带可选依赖项的版本
bash1pip install lightning['extra']
Conda 安装
bash1conda install lightning -c conda-forge
安装稳定版本
从源码安装未来发布的稳定版本
bash1pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U
安装最新版本
从源码安装 nightly 版本(不保证稳定性)
bash1pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
或从测试 PyPI 安装
bash1pip install -iU https://test.pypi.org/simple/ pytorch-lightning
Lightning 有两个核心包
PyTorch Lightning: 大规模训练和部署 PyTorch。
Lightning Fabric: 专家控制。
Lightning 让你可以精细控制在 PyTorch 上增加的抽象层次。
PyTorch Lightning: 大规模训练和部署 PyTorch
PyTorch Lightning 只是更有组织的 PyTorch - Lightning 解耦了 PyTorch 代码,将科学与工程分离。
示例
探索使用 PyTorch Lightning 进行各种类型的训练。预训练和微调任何类型的模型,以执行分类、分割、摘要等任务:
任务 | 描述 | 运行 |
---|---|---|
Hello world | 预训练 - Hello world 示例 | |
图像分割 | 微调 - 使用 ResNet-50 模型进行图像分割 | |
文本分类 | 微调 - 文本分类器(BERT 模型) | |
文本摘要 | 微调 - 文本摘要(Hugging Face transformer 模型) | |
音频生成 | 微调 - 音频生成(transformer 模型) |
Hello 简单模型
python1# main.py 2# ! pip install torchvision 3import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F 4import lightning as L 5 6# -------------------------------- 7# 步骤1: 定义一个 LightningModule 8# -------------------------------- 9# 一个 LightningModule(nn.Module 的子类)定义了一个完整的*系统* 10# (例如:一个 LLM、扩散模型、自编码器,或一个简单的图像分类器)。 11 12 13class LitAutoEncoder(L.LightningModule): 14 def __init__( 15 16self): 17 super().__init__() 18 self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) 19 self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) 20 21 def forward(self, x): 22 # 在 lightning 中,forward 定义了预测/推理行为 23 embedding = self.encoder(x) 24 return embedding 25 26 def training_step(self, batch, batch_idx): 27 # training_step 定义了训练循环。它独立于 forward 28 x, _ = batch 29 x = x.view(x.size(0), -1) 30 z = self.encoder(x) 31 x_hat = self.decoder(z) 32 loss = F.mse_loss(x_hat, x) 33 self.log("train_loss", loss) 34 return loss 35 36 def configure_optimizers(self): 37 optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 38 return optimizer 39 40 41# ------------------- 42# 步骤2: 定义数据 43# ------------------- 44dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) 45train, val = data.random_split(dataset, [55000, 5000]) 46 47# ------------------- 48# 步骤3: 训练 49# ------------------- 50autoencoder = LitAutoEncoder() 51trainer = L.Trainer() 52trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))
在终端运行模型
bash1pip install torchvision 2python main.py
高级功能
Lightning 具有超过40+高级功能,专为大规模专业AI研究设计。
以下是一些示例:
在成千上万的GPU上训练而无需更改代码
python1# 8 个 GPU 2# 无需代码更改 3trainer = Trainer(accelerator="gpu", devices=8) 4 5# 256 个 GPU 6trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
在其他加速器(如 TPU)上训练而无需更改代码
python1# 无需代码更改 2trainer = Trainer(accelerator="tpu", devices=8)
16位精度
python1# 无需代码更改 2trainer = Trainer(precision=16)
实验管理器
python1from lightning import loggers 2 3# tensorboard 4trainer = Trainer(logger=TensorBoardLogger("logs/")) 5 6# weights and biases 7trainer = Trainer(logger=loggers.WandbLogger()) 8 9# comet 10trainer = Trainer(logger=loggers.CometLogger()) 11 12# mlflow 13trainer = Trainer(logger=loggers.MLFlowLogger()) 14 15# neptune 16trainer = Trainer(logger=loggers.NeptuneLogger()) 17 18# 还有很多
早停
python1es = EarlyStopping(monitor="val_loss") 2trainer = Trainer(callbacks=[es])
检查点
python1checkpointing = ModelCheckpoint(monitor="val_loss") 2trainer = Trainer(callbacks=[checkpointing])
导出到 torchscript (JIT)(生产用途)
python1# torchscript 2autoencoder = LitAutoEncoder() 3torch.jit.save(autoencoder.to_torchscript(), "model.pt")
导出到 ONNX(生产用途)
python1# onnx 2with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile: 3 autoencoder = LitAutoEncoder() 4 input_sample = torch.randn((1, 64)) 5 autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True) 6 os.path.isfile(tmpfile.name)
相比未结构化的 PyTorch 的优势
- 模型变得与硬件无关
- 代码清晰易读,因为工程代码被抽象掉了
- 更容易复现
- 减少错误,因为 Lightning 处理了复杂的工程问题
- 保留所有灵活性(LightningModules 仍然是 PyTorch 模块),但去除了大量的样板代码
- Lightning 与流行的机器学习工具有数十种集成。
- 每个新的 PR 都经过严格测试。我们测试每个支持的 PyTorch 和 Python 版本组合,每个操作系统,多 GPU 甚至 TPU。
- 运行速度开销最小(每个 epoch 仅比纯 PyTorch 慢约 300 毫秒)。
Lightning Fabric: 专家控制
在任何设备上、任何规模下运行,拥有对 PyTorch 训练循环和扩展策略的专家级控制。你甚至可以编写自己的 Trainer。
Fabric 专为最复杂的模型设计,如基础模型扩展、LLM、扩散模型、Transformer、强化学习、主动学习等。适用于任何规模。
更改内容 | 结果 Fabric 代码(复制我!) |
---|---|
|
|
主要特性
轻松切换从 CPU 运行到 GPU(Apple Silicon、CUDA、…)、TPU、多 GPU 甚至多节点训练
python1# 使用你的可用硬件 2# 无需代码更改 3fabric = Fabric() 4 5# 在 GPU(CUDA 或 MPS)上运行 6fabric = Fabric(accelerator="gpu") 7 8# 8 个 GPU 9fabric = Fabric(accelerator="gpu", devices=8) 10 11# 256 个 GPU,多节点 12fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32) 13 14# 在 TPU 上运行 15fabric = Fabric(accelerator="tpu")
开箱即用的最先进的分布式训练策略(DDP、FSDP、DeepSpeed)和混合精度
python1# 使用最先进的分布式训练技术 2fabric = Fabric(strategy="ddp") 3fabric = Fabric(strategy="deepspeed") 4fabric = Fabric(strategy="fsdp") 5 6# 切换精度 7fabric = Fabric(precision="16-mixed") 8fabric = Fabric(precision="64")
所有设备逻辑样板代码都为你处理
diff1 # 不再需要这些! 2- model.to(device) 3- batch.to(device)
使用 Fabric 原语构建你自己的自定义 Trainer 进行训练检查点记录、日志记录等
python
import lightning as L
class MyCustomTrainer:
def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)
def fit(self, model, optimizer, dataloader, max_epochs):
self.fabric.launch()
model, optimizer = self.fabric.setup(model, optimizer)
dataloader = self.fabric.setup_dataloaders(dataloader)
model.train()
for epoch in range(max_epochs):
for batch in dataloader:
input, target = batch
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
self.fabric.backward(loss)
optimizer.step()
你可以在我们的示例中找到更详细的示例
示例
自监督学习
卷积架构
强化学习
GANs
经典机器学习
持续集成
Lightning 在多个 CPU、GPU 和 TPU 以及主要的 Python 和 PyTorch 版本上进行了严格测试。
*Codecov 超过 90%+,但构建延迟可能显示更少
当前构建状态
系统 / PyTorch 版本 | 1.13 | 2.0 | 2.1 |
---|---|---|---|
Linux py3.9 [GPUs] | |||
Linux py3.9 [TPUs] | |||
Linux (多个 Python 版本) | |||
OSX (多个 Python 版本) | |||
Windows (多个 Python 版本) |
社区
Lightning 社区由以下人员维护
- 10 多名核心贡献者,他们都是来自顶级 AI 实验室的专业工程师、研究科学家和博士生的混合体。
- 800 多名社区贡献者。
想帮助我们构建 Lightning 并减少数千名研究人员的样板代码吗?了解如何在此处进行首次贡献
Lightning 也是PyTorch 生态系统的一部分,这要求项目必须有坚实的测试、文档和支持。
请求帮助
如果你有任何问题,请: