项目详情
深入了解 pytorch-lightning 的功能与特性
安装 Lightning
从 PyPI 简单安装
bash1pip install lightning
其他安装选项
安装带可选依赖项的版本
bash
Pretrain, finetune ANY AI model of ANY size on 1 or 10,000+ GPUs with zero code changes.
深入了解 pytorch-lightning 的功能与特性
从 PyPI 简单安装
bash1pip install lightning
bash
发现更多类似的优秀工具
1pip install lightning['extra']bash1conda install lightning -c conda-forge
从源码安装未来发布的稳定版本
bash1pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U
从源码安装 nightly 版本(不保证稳定性)
bash1pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
或从测试 PyPI 安装
bash1pip install -iU https://test.pypi.org/simple/ pytorch-lightning
Lightning 让你可以精细控制在 PyTorch 上增加的抽象层次。
PyTorch Lightning 只是更有组织的 PyTorch - Lightning 解耦了 PyTorch 代码,将科学与工程分离。

探索使用 PyTorch Lightning 进行各种类型的训练。预训练和微调任何类型的模型,以执行分类、分割、摘要等任务:
| 任务 | 描述 | 运行 |
|---|---|---|
| Hello world | 预训练 - Hello world 示例 | |
| 图像分割 | 微调 - 使用 ResNet-50 模型进行图像分割 | |
| 文本分类 | 微调 - 文本分类器(BERT 模型) | |
| 文本摘要 | 微调 - 文本摘要(Hugging Face transformer 模型) | |
| 音频生成 | 微调 - 音频生成(transformer 模型) |
python1# main.py 2# ! pip install torchvision 3import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F 4import lightning as L 5 6# -------------------------------- 7# 步骤1: 定义一个 LightningModule 8# -------------------------------- 9# 一个 LightningModule(nn.Module 的子类)定义了一个完整的*系统* 10# (例如:一个 LLM、扩散模型、自编码器,或一个简单的图像分类器)。 11 12 13class LitAutoEncoder(L.LightningModule): 14 def __init__( 15 16self): 17 super().__init__() 18 self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) 19 self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) 20 21 def forward(self, x): 22 # 在 lightning 中,forward 定义了预测/推理行为 23 embedding = self.encoder(x) 24 return embedding 25 26 def training_step(self, batch, batch_idx): 27 # training_step 定义了训练循环。它独立于 forward 28 x, _ = batch 29 x = x.view(x.size(0), -1) 30 z = self.encoder(x) 31 x_hat = self.decoder(z) 32 loss = F.mse_loss(x_hat, x) 33 self.log("train_loss", loss) 34 return loss 35 36 def configure_optimizers(self): 37 optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 38 return optimizer 39 40 41# ------------------- 42# 步骤2: 定义数据 43# ------------------- 44dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) 45train, val = data.random_split(dataset, [55000, 5000]) 46 47# ------------------- 48# 步骤3: 训练 49# ------------------- 50autoencoder = LitAutoEncoder() 51trainer = L.Trainer() 52trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))
在终端运行模型
bash1pip install torchvision 2python main.py
Lightning 具有超过40+高级功能,专为大规模专业AI研究设计。
以下是一些示例:
python1# 8 个 GPU 2# 无需代码更改 3trainer = Trainer(accelerator="gpu", devices=8) 4 5# 256 个 GPU 6trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
python1# 无需代码更改 2trainer = Trainer(accelerator="tpu", devices=8)
python1# 无需代码更改 2trainer = Trainer(precision=16)
python1from lightning import loggers 2 3# tensorboard 4trainer = Trainer(logger=TensorBoardLogger("logs/")) 5 6# weights and biases 7trainer = Trainer(logger=loggers.WandbLogger()) 8 9# comet 10trainer = Trainer(logger=loggers.CometLogger()) 11 12# mlflow 13trainer = Trainer(logger=loggers.MLFlowLogger()) 14 15# neptune 16trainer = Trainer(logger=loggers.NeptuneLogger()) 17 18# 还有很多
python1es = EarlyStopping(monitor="val_loss") 2trainer = Trainer(callbacks=[es])
python1checkpointing = ModelCheckpoint(monitor="val_loss") 2trainer = Trainer(callbacks=[checkpointing])
python1# torchscript 2autoencoder = LitAutoEncoder() 3torch.jit.save(autoencoder.to_torchscript(), "model.pt")
python1# onnx 2with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile: 3 autoencoder = LitAutoEncoder() 4 input_sample = torch.randn((1, 64)) 5 autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True) 6 os.path.isfile(tmpfile.name)
在任何设备上、任何规模下运行,拥有对 PyTorch 训练循环和扩展策略的专家级控制。你甚至可以编写自己的 Trainer。
Fabric 专为最复杂的模型设计,如基础模型扩展、LLM、扩散模型、Transformer、强化学习、主动学习等。适用于任何规模。
| 更改内容 | 结果 Fabric 代码(复制我!) |
|---|---|
|
|
python1# 使用你的可用硬件 2# 无需代码更改 3fabric = Fabric() 4 5# 在 GPU(CUDA 或 MPS)上运行 6fabric = Fabric(accelerator="gpu") 7 8# 8 个 GPU 9fabric = Fabric(accelerator="gpu", devices=8) 10 11# 256 个 GPU,多节点 12fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32) 13 14# 在 TPU 上运行 15fabric = Fabric(accelerator="tpu")
python1# 使用最先进的分布式训练技术 2fabric = Fabric(strategy="ddp") 3fabric = Fabric(strategy="deepspeed") 4fabric = Fabric(strategy="fsdp") 5 6# 切换精度 7fabric = Fabric(precision="16-mixed") 8fabric = Fabric(precision="64")
diff1 # 不再需要这些! 2- model.to(device) 3- batch.to(device)
python
import lightning as L
class MyCustomTrainer:
def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)
def fit(self, model, optimizer, dataloader, max_epochs):
self.fabric.launch()
model, optimizer = self.fabric.setup(model, optimizer)
dataloader = self.fabric.setup_dataloaders(dataloader)
model.train()
for epoch in range(max_epochs):
for batch in dataloader:
input, target = batch
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
self.fabric.backward(loss)
optimizer.step()
你可以在我们的示例中找到更详细的示例
Lightning 在多个 CPU、GPU 和 TPU 以及主要的 Python 和 PyTorch 版本上进行了严格测试。
Lightning 社区由以下人员维护
想帮助我们构建 Lightning 并减少数千名研究人员的样板代码吗?了解如何在此处进行首次贡献
Lightning 也是PyTorch 生态系统的一部分,这要求项目必须有坚实的测试、文档和支持。
如果你有任何问题,请: