Skip to content

Latest commit

 

History

History
183 lines (118 loc) · 4.37 KB

lightning-transformers.md

File metadata and controls

183 lines (118 loc) · 4.37 KB

lightning-transformers

Pytorch-Lightning

LightningDataModule

LightningDataModule定义了5个api:

  • prepare_data (how to download(), tokenize, etc…)
  • setup (how to split, etc…)
  • train_dataloader
  • val_dataloader(s)
  • test_dataloader(s)

prepare_data

该函数负责预处理数据,包括下载,tokenize等。

注意:prepare_data is called from a single process (e.g. GPU 0). Do not use it to assign state (self.x = y).

setup

如果希望对数据的操作能在每个GPU上生效,可以通过setup实现一下逻辑:

  • 统计分类数
  • 构建词典
  • 数据分割(train/val/test)
  • 数据转换
  • ...

例如:

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage: Optional[str] = None):

        # Assign Train/val split(s) for use in Dataloaders
        if stage in (None, "fit"):
            mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.dims = self.mnist_train[0][0].shape

        # Assign Test split(s) for use in Dataloaders
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)
            self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)

stagesetup可选参数。为trainer实现分割逻辑trainer.{fit,validate,test}。如果stage=None,包含fit/validate/test全部逻辑。

train_dataloader/val_dataloader/test_dataloader

train_dataloader/val_dataloader/test_dataloader封装并返回setup中分割的数据。例如:

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
        
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
        
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

predict_dataloader

返回用于推理的数据集,用于trainer.predict方法使用。 如下所示:

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

LightningModule

LightningModule其实也是torch.nn.Module,包括5部分:

  • init
  • training_step
  • validation_step
  • test_step
  • configure_optimizers

更详细的用法参考LightningModule

Datasets of transformers

datasets后端采用Apache Arrow格式,极大的提高了数据的处理速度。以下主要从datasets加载数据、处理等两方面做介绍。其他功能参考datasets文档

datasets数据加载

数据加载支持以下加载方式:

  • The Hub
  • Local files
  • In-memory data
  • Offline
  • A specific slice of a split

加载本地及远程文件

支持加载csv/json/txt/Parquet等文件

from datasets import load_dataset

file_type = "csv" # json/text/parquet
data_files = {"train": train_file_path}
# data_files = train_file_path
dataset = load_dataset(file_type, data_files=data_files)

从内存加载

从dict对象加载

from datasets import Dataset

my_dict = {"a": [1, 2, 3]}

dataset = Dataset.from_dict(my_dict)

从pandas Dataframe加载

import pandas as pd
from datasets import Dataset


df = pd.DataFrame({"a": [1, 2, 3]})

dataset = Dataset.from_pandas(df)

datasets数据处理

主要包括

  • 排序和分割
  • 删除、重命名columns及其他列操作
  • 转换函数
  • datasets合并
  • 自定义格式转换
  • 保存或导出数据

排序和分割

sort 按列排序:

sorted_dataset = dataset.sort('label')

Shuffle

随机打乱:

shuffled_dataset = sorted_dataset.shuffle(seed=42)

Select and Filter

datasets使用两个函数用于过滤数据:datasets.Dataset.select() and datasets.Dataset.filter().

Split

分割训练集和测试集数据,使用datasets.Dataset.train_test_split(),默认shuffle=True。该方法类似sklearn的train_test_split,但是没有stratified选项。