A general representation model across vision, audio, language modalities. Paper: ONE-PEACE: Exploring One General Representation Model Toward Unlimited Modalities

OFA-Sys/ONE-PEACE

📖 Paper   |  🤗 Demo   |   🤖 ModelScope   |   Checkpoints  |  Datasets


ONE-PEACE is a general representation model across vision, audio, and language modalities. Without using any vision or language pretrained model for initialization, ONE-PEACE achieves leading results in vision, audio, audio-language, and vision-language tasks. Furthermore, ONE-PEACE possesses a strong emergent zero-shot retrieval capability, enabling it to align modalities that are not paired in the training data.

The figure below shows the architecture and pretraining tasks of ONE-PEACE. With its scaling-friendly architecture and modality-agnostic tasks, ONE-PEACE has the potential to expand to unlimited modalities.


Online Demo

We provide an online demo on Hugging Face Spaces. In this demo, you can combine multiple modalities to retrieve related images, such as audio-to-image, audio+text-to-image, audio+image-to-image, and even audio+image+text-to-image.
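
One simple way to picture a combined query such as audio+text-to-image is to add the unit-length embeddings of each input and renormalize the sum before ranking images by cosine similarity. The sketch below uses toy 3-d vectors in place of real model outputs. The additive fusion and the helper names (`compose_query`, `rank_candidates`) are illustrative assumptions, not a description of how the demo is implemented.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def compose_query(*embeddings):
    """Fuse several modality embeddings by summing their unit vectors.

    Assumption: plain additive fusion; the demo may combine queries differently.
    """
    units = [l2_normalize(e) for e in embeddings]
    summed = [sum(dims) for dims in zip(*units)]
    return l2_normalize(summed)

def rank_candidates(query, candidates):
    """Return candidate indices sorted by cosine similarity, best first."""
    q = l2_normalize(query)
    scores = [sum(a * b for a, b in zip(q, l2_normalize(c))) for c in candidates]
    return sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)

# Toy 3-d embeddings standing in for real model outputs.
audio_emb = [1.0, 0.0, 0.0]
text_emb = [0.0, 1.0, 0.0]
image_bank = [
    [1.0, 0.0, 0.0],  # close to the audio query only
    [0.7, 0.7, 0.0],  # close to both the audio and the text query
    [0.0, 0.0, 1.0],  # close to neither
]

query = compose_query(audio_emb, text_emb)
ranking = rank_candidates(query, image_bank)
print(ranking)  # the image that matches both inputs is ranked first
```

With real features, the toy lists would be replaced by rows of the tensors returned by the feature extraction methods shown in the API section.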


News

  • 2023.7.20: Released the visual grounding API, which you can use to locate objects in a picture.
  • 2023.6.23: Released vision tasks fine-tuning scripts and checkpoints. See guidance for vision tasks for more details.
  • 2023.6.04: Released the pretraining scripts. See guidance for pretraining for more details.
  • 2023.5.30: Released the finetuned checkpoints and scripts for audio(-language) tasks.
  • 2023.5.29: Released the finetuned checkpoints for vision-language tasks.
  • 2023.5.27: 🔥 We have released a multimodal retrieval demo on Hugging Face Spaces. Have fun!
  • 2023.5.25: Released the multimodal embedding API, which enables quick extraction of image, audio, and text representations.
  • 2023.5.23: Released the pretrained checkpoint, as well as finetuning & inference scripts for vision-language tasks.
  • 2023.5.19: Released the paper and code. Pretrained & finetuned checkpoints, training & inference scripts, as well as demos will be released as soon as possible.

Models and Results

Model Card

We list the parameters and pretrained checkpoints of ONE-PEACE below. Note that ONE-PEACE can be disassembled into different branches to handle different tasks. We also provide the vision branch of ONE-PEACE, which can be used on its own for vision tasks.

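| Model | Ckpt | Params | Hidden size | Intermediate size | Attention heads | Layers |
| --- | --- | --- | --- | --- | --- | --- |
| ONE-PEACE | Download | 4B | 1536 | 6144 | 24 | 40 |
| ONE-PEACE (Vision Branch) | Download | 1.5B | 1536 | 6144 | 24 | 40 |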
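
As a rough sanity check on the parameter counts above, the sizes in the model card can be plugged into a back-of-the-envelope estimate. This sketch assumes each layer has one shared self-attention block and one gated feed-forward network (three weight matrices) per modality; that layer layout is an assumption that this README does not state. Biases, normalization layers, adapters, and embeddings are ignored, so the result is only approximate.

```python
def estimate_params(hidden, intermediate, layers, num_modalities):
    """Rough transformer weight count from the model-card sizes.

    Assumptions (not stated in this README):
      * one shared self-attention block per layer (Q, K, V, output projections)
      * one gated FFN per modality, with three weight matrices each
    Biases, norms, adapters and embeddings are left out.
    """
    attention = 4 * hidden * hidden
    gated_ffn = 3 * hidden * intermediate
    return layers * (attention + num_modalities * gated_ffn)

# Sizes taken from the model card: hidden 1536, intermediate 6144, 40 layers.
full = estimate_params(1536, 6144, 40, num_modalities=3)
vision = estimate_params(1536, 6144, 40, num_modalities=1)
print(f"all branches: {full / 1e9:.2f}B, vision branch: {vision / 1e9:.2f}B")
# prints "all branches: 3.77B, vision branch: 1.51B"
```

Under these assumptions the estimates land close to the listed 4B and 1.5B.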

Results

Vision Tasks

| Task | Image classification | Semantic Segmentation | Object Detection (w/o Object365) | Video Action Recognition |
| --- | --- | --- | --- | --- |
| Dataset | Imagenet-1K | ADE20K | COCO | Kinetics 400 |
| Split | val | val | val | val |
| Metric | Acc. | mIoU^ss / mIoU^ms | AP^box / AP^mask | Top-1 Acc. / Top-5 Acc. |
| ONE-PEACE | 89.8 | 62.0 / 63.0 | 60.4 / 52.9 | 88.1 / 97.8 |

Audio Tasks

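| Task | Audio-Text Retrieval | Audio-Text Retrieval | Audio Classification | Audio Classification | Audio Classification | Audio Question Answering |
| --- | --- | --- | --- | --- | --- | --- |
| Dataset | AudioCaps | Clotho | ESC-50 | FSD50K | VGGSound (Audio-Visual) | AVQA |
| Split | test | evaluation | full | eval | test | val |
| Metric | T2A R@1 / A2T R@1 | T2A R@1 / A2T R@1 | Zero-shot Acc. | MAP | Acc. | Acc. |
| ONE-PEACE | 42.5 / 51.0 | 22.4 / 27.1 | 91.8 | 69.7 | 68.2 | 92.2 |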

Vision-Language Tasks

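| Task | Image-Text Retrieval (w/o ranking) | Image-Text Retrieval (w/o ranking) | Visual Grounding | Visual Grounding | Visual Grounding | VQA | Visual Reasoning |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Dataset | COCO | Flickr30K | RefCOCO | RefCOCO+ | RefCOCOg | VQAv2 | NLVR2 |
| Split | test | test | val / testA / testB | val / testA / testB | val-u / test-u | test-dev / test-std | dev / test-P |
| Metric | I2T R@1 / T2I R@1 | I2T R@1 / T2I R@1 | Acc@0.5 | Acc@0.5 | Acc@0.5 | Acc. | Acc. |
| ONE-PEACE | 84.1 / 65.4 | 97.6 / 89.6 | 92.58 / 94.18 / 89.26 | 88.77 / 92.21 / 83.23 | 89.22 / 89.27 | 82.6 / 82.5 | 87.8 / 88.3 |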


Requirements and Installation

  • 3.6 <= Python <= 3.10
  • PyTorch >= 1.10.0 (1.13.1 recommended)
  • CUDA version >= 10.2 (11.6 recommended)
  • Install required packages:
git clone https://github.com/OFA-Sys/ONE-PEACE
cd ONE-PEACE
pip install -r requirements.txt
  • For faster training, install the Apex library (optional):
git clone https://github.com/NVIDIA/apex
cd apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
  • Install the xFormers library to use memory-efficient attention (optional):
conda install xformers -c xformers
  • Install the FlashAttention library and its fused LayerNorm extension (optional):
git clone --recursive https://github.com/HazyResearch/flash-attention
cd flash-attention && pip install .
cd csrc/layer_norm && pip install .
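
Before running anything heavier, it can help to confirm that the interpreter falls in the supported range and to see which optional libraries are present. The following is a minimal sketch using only the standard library. The helper names are made up for illustration, and the import names `apex`, `xformers`, and `flash_attn` are the usual module names for those packages rather than something this README specifies.

```python
import importlib.util
import sys

def python_supported(version=None):
    """True if the (major, minor) version satisfies 3.6 <= Python <= 3.10."""
    major_minor = tuple((version or sys.version_info)[:2])
    return (3, 6) <= major_minor <= (3, 10)

def installed_packages(names=("torch", "apex", "xformers", "flash_attn")):
    """Map each import name to whether it can be located, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

print("Python version supported:", python_supported())
for name, found in installed_packages().items():
    print(f"{name}: {'found' if found else 'missing'}")
```

Only `torch` is required; a missing optional package just means the corresponding speed-up is unavailable.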

Datasets and Checkpoints

See datasets.md and checkpoints.md.

Usage

API

The following code snippets show how to use the ONE-PEACE API.

Multi-modal Embedding

We use ONE-PEACE to compute embeddings for text, images, and audio, as well as their similarities:

import torch
from one_peace.models import from_pretrained

device = "cuda" if torch.cuda.is_available() else "cpu"
# "ONE-PEACE" can also be replaced with ckpt path
model = from_pretrained("ONE-PEACE", device=device, dtype="float32")

# process raw data
src_tokens = model.process_text(["cow", "dog", "elephant"])
src_images = model.process_image(["assets/dog.JPEG", "assets/elephant.JPEG"])
src_audios, audio_padding_masks = model.process_audio(["assets/cow.flac", "assets/dog.flac"])

with torch.no_grad():
    # extract normalized features
    text_features = model.extract_text_features(src_tokens)
    image_features = model.extract_image_features(src_images)
    audio_features = model.extract_audio_features(src_audios, audio_padding_masks)

    # compute similarity
    i2t_similarity = image_features @ text_features.T
    a2t_similarity = audio_features @ text_features.T

print("Image-to-text similarities:", i2t_similarity)
print("Audio-to-text similarities:", a2t_similarity)
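
The similarity matrices printed above can be read row by row: the highest score in a row marks the best-matching text for that image or audio clip. The sketch below works on plain nested lists (for example the result of `i2t_similarity.tolist()`) and also shows a simple Recall@K computation of the kind reported as R@1 in the results tables. It assumes query `i` has exactly one correct candidate at index `i`, which is a simplification of the benchmark protocols; the scores are made-up toy values.

```python
def best_match(similarity_row, labels):
    """Return the label with the highest similarity score in one row."""
    best_index = max(range(len(similarity_row)), key=similarity_row.__getitem__)
    return labels[best_index]

def recall_at_k(similarity, k=1):
    """Fraction of queries whose true match (same index) is among the top-k candidates."""
    hits = 0
    for query_index, row in enumerate(similarity):
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        hits += query_index in top_k
    return hits / len(similarity)

# Toy image-to-text scores for a dog image and an elephant image.
labels = ["cow", "dog", "elephant"]
toy_i2t = [
    [0.10, 0.80, 0.20],
    [0.05, 0.10, 0.90],
]
matches = [best_match(row, labels) for row in toy_i2t]
print(matches)

# Toy square matrix: query i should retrieve candidate i.
toy_square = [
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.3],
    [0.1, 0.6, 0.4],
]
print(recall_at_k(toy_square, k=1), recall_at_k(toy_square, k=2))
```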

Visual Grounding

We use ONE-PEACE to perform visual grounding on anime pictures:

import torch
import cv2
from one_peace.models import from_pretrained

device = "cuda" if torch.cuda.is_available() else "cpu"
model = from_pretrained(
    "ONE-PEACE_Grounding",
    model_type="one_peace_classify",
    device=device,
    dtype="float32"
)

# process raw data
image_text_list = [
    ("assets/pokemons.jpg", "a blue turtle-like pokemon with round head"),
    ("assets/pokemons.jpg", "Bulbasaur"),
    ("assets/pokemons.jpg", "Charmander"),
    ("assets/pokemons.jpg", "Squirtle"),
    ("assets/one_piece.jpeg", "Brook"),
    ("assets/one_piece.jpeg", "Franky"),
    ("assets/one_piece.jpeg", "Monkey D. Luffy"),
    ("assets/one_piece.jpeg", "Nami"),
    ("assets/one_piece.jpeg", "Nico Robin"),
    ("assets/one_piece.jpeg", "Roronoa Zoro"),
    ("assets/one_piece.jpeg", "Tony Tony Chopper"),
    ("assets/one_piece.jpeg", "Usopp"),
    ("assets/one_piece.jpeg", "Vinsmoke Sanji"),
]
(src_images, image_widths, image_heights), src_tokens = model.process_image_text_pairs(
    image_text_list, return_image_sizes=True
)

with torch.no_grad():
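    # extract features; sigmoid maps the box outputs into the [0, 1] range
    vl_features = model.extract_vl_features(src_images, src_tokens).sigmoid()
    # scale to pixel coordinates: even columns are x values, odd columns are y values
    vl_features[:, ::2] *= image_widths.unsqueeze(1)
    vl_features[:, 1::2] *= image_heights.unsqueeze(1)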
    coords = vl_features.cpu().tolist()

# display results
for i, image_text_pair in enumerate(image_text_list):
    image, text = image_text_pair
    img = cv2.imread(image)
    cv2.rectangle(
        img,
        (int(coords[i][0]), int(coords[i][1])),
        (int(coords[i][2]), int(coords[i][3])),
        (0, 255, 0),
        3
    )
    cv2.imshow(text, img)
    cv2.waitKey(3500)
    cv2.destroyAllWindows()
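
The grounding snippet rescales normalized box outputs to pixel coordinates before drawing them. The sketch below reproduces that scaling for a single box in plain Python and adds an intersection-over-union check of the kind behind the Acc@0.5 metric in the results table. The helper names and toy boxes are illustrative assumptions, and counting a prediction as correct when IoU >= 0.5 is one common convention rather than a documented detail of the evaluation scripts.

```python
def to_pixels(box, width, height):
    """Scale a normalized (x1, y1, x2, y2) box to pixel coordinates."""
    x1, y1, x2, y2 = box
    return (x1 * width, y1 * height, x2 * width, y2 * height)

def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def accuracy_at_iou(predictions, targets, threshold=0.5):
    """Share of predictions whose IoU with the target reaches the threshold (assumed >=)."""
    hits = sum(iou(p, t) >= threshold for p, t in zip(predictions, targets))
    return hits / len(targets)

# Toy example: one normalized prediction in a 400x200 image, one shifted box.
pred_a = to_pixels((0.25, 0.25, 0.75, 0.75), width=400, height=200)
predictions = [pred_a, (0, 0, 100, 100)]
targets = [(100, 50, 300, 150), (50, 0, 150, 100)]
print(pred_a, accuracy_at_iou(predictions, targets))
```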

Audio Classification

We use ONE-PEACE to perform audio classification:

import torch
import json
from one_peace.models import from_pretrained

# load the id-to-label mapping, closing the file once it has been read
with open("assets/vggsound_id2label.json") as f:
    id2label = json.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = from_pretrained(
    "ONE-PEACE_VGGSound",
    model_type="one_peace_classify",
    device=device,
    dtype="float32"
)

# process audio
audio_list = ["assets/cow.flac", "assets/dog.flac"]
src_audios, audio_padding_masks = model.process_audio(audio_list)

with torch.no_grad():
    # with the classification checkpoint, the extracted features are per-class logits
    audio_logits = model.extract_audio_features(src_audios, audio_padding_masks)
    print(audio_logits.size())
    predict_label_ids = audio_logits.argmax(1).cpu().tolist()

for audio, predict_label_id in zip(audio_list, predict_label_ids):
    predict_label = id2label[str(predict_label_id)]
    print('audio: {}, predict label: {}'.format(audio, predict_label))
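
The classification snippet reports only the top label. To see a few likely labels with rough confidence values, the logits for one clip can be passed through a softmax and sorted. The following is a minimal sketch on plain Python lists (for example `audio_logits[0].cpu().tolist()`); the helper names and the three-entry label map are hypothetical stand-ins, not the contents of `assets/vggsound_id2label.json`.

```python
import math

def softmax(logits):
    """Convert raw logits to probabilities, subtracting the max for numerical stability."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_labels(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(id2label[str(i)], probs[i]) for i in order]

# Hypothetical label map and logits, keyed by string ids like the JSON file.
toy_id2label = {"0": "label_a", "1": "label_b", "2": "label_c"}
toy_logits = [2.0, 0.5, -1.0]

top2 = top_k_labels(toy_logits, toy_id2label, k=2)
for label, prob in top2:
    print(f"{label}: {prob:.3f}")
```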

Training & Inference

If the API alone does not cover your needs, we provide comprehensive training and inference instructions for audio & multimodal tasks and for vision tasks.



Gallery

Visual Grounding (unseen domain)

[Image: visual grounding examples]

Emergent Zero-shot Retrieval

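[Image: audio-to-image retrieval (a2i)]

[Image: audio+text-to-image retrieval (a+t2i)]

[Image: audio+image-to-image retrieval (a+i2i)]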

Acknowledgement

  • Fairseq: A sequence modeling toolkit with flexible configuration and highly extensible code structure.
  • xFormers: A toolbox to accelerate research on Transformers.
  • FlashAttention: A repository that provides the official implementation of FlashAttention, which greatly speeds up multi-head attention.
  • Apex: A repository that provides useful model acceleration and memory optimization techniques.

Getting Involved

Feel free to submit GitHub issues or pull requests. Contributions to the project are welcome!

To contact us, don't hesitate to send an email to zheluo.wp@alibaba-inc.com or saimeng.wsj@alibaba-inc.com!

Citation

If you find our paper and code useful in your research, please consider giving a star ⭐ and citation 📝 :)

@article{wang2023one,
  title={ONE-PEACE: Exploring One General Representation Model Toward Unlimited Modalities},
  author={Wang, Peng and Wang, Shijie and Lin, Junyang and Bai, Shuai and Zhou, Xiaohuan and Zhou, Jingren and Wang, Xinggang and Zhou, Chang},
  journal={arXiv preprint arXiv:2305.11172},
  year={2023}
}