For the Chinese version of the README, please refer to 中文文档.
- Data Preprocessing: The relevant code is located in the `dataprocess` folder, and dataset-related code is in the `dataset` folder. Data preprocessing mainly includes path merging, QA data concatenation, feature-insertion token processing, etc.
- LLM Model: Uses Qwen-7B as the main model, with the relevant code in the `qwen` folder. By overriding the `forward` method of `QWenModel`, multimodal feature injection is achieved.
- Visual Model: Uses `CLIP_VIT` and `SIGLIP_VIT`, with the relevant code in the `visual` folder, which also includes other backbone networks.
- VLM Model: The relevant code is in the `model.py` file under the `model` folder.
We use multilingual data, mainly the COCO 2017 dataset and the AI Challenger Chinese image caption dataset:
- The COCO dataset annotations use LLaVA's `detail_23k` and `complex_reasoning_77k`, which can effectively enhance the richness of the model's descriptions.
- The AI Challenger dataset uses its original annotations together with a fixed prompt (see the sketch below).
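As a rough illustration of the fixed-prompt concatenation for the AI Challenger captions, here is a minimal sketch. The field names, the prompt text, and the placement of the `<|extra_0|>` placeholder are assumptions for illustration, not the repo's actual `dataprocess` code.

```python
# Hypothetical sketch of turning an AI Challenger caption into a QA-style sample.
# The fixed prompt, field names, and placeholder usage are assumptions.
import json

FIXED_PROMPT = "Describe this image.<|extra_0|>"  # assumed fixed prompt; in practice presumably Chinese


def build_sample(image_name: str, caption: str) -> dict:
    """Concatenate the fixed prompt (question) with the original caption (answer)."""
    return {"image": image_name, "question": FIXED_PROMPT, "answer": caption}


if __name__ == "__main__":
    sample = build_sample("00001.jpg", "一个穿着红色上衣的女人走在沙滩上")  # an AI Challenger-style Chinese caption
    print(json.dumps(sample, ensure_ascii=False, indent=2))
```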
In the VLM, the visual part uses the `CLIP` or `SIGLIP` model, which has already achieved preliminary semantic alignment, and a two-layer MLP is used for feature mapping. By overriding the `forward` method of `QWenModel`, the corresponding image tokens are replaced with visual features.
If you wish to replace the model architecture, please modify this part.
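The sketch below shows one way such an overridden `forward` could splice projected visual features into the token embeddings at the image-placeholder positions. It is a minimal illustration, not the repository's actual implementation; the MLP sizes, the `image_token_id`, and the helper names are assumptions.

```python
# Minimal sketch of visual-feature injection (not the repo's actual code).
import torch
import torch.nn as nn


class FeatureProjector(nn.Module):
    """Two-layer MLP mapping vision features into the LLM embedding space."""

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096):  # sizes are assumptions
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(vision_dim, llm_dim), nn.GELU(), nn.Linear(llm_dim, llm_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


def inject_visual_features(inputs_embeds: torch.Tensor,
                           input_ids: torch.Tensor,
                           visual_embeds: torch.Tensor,
                           image_token_id: int) -> torch.Tensor:
    """Replace embeddings at image-placeholder positions with projected visual features.

    inputs_embeds: (batch, seq_len, llm_dim) text embeddings from the LLM's embedding layer.
    visual_embeds: (batch, num_patches, llm_dim) vision features already passed through FeatureProjector.
    """
    out = inputs_embeds.clone()
    for b in range(input_ids.size(0)):
        positions = (input_ids[b] == image_token_id).nonzero(as_tuple=True)[0]
        # Assumes each prompt contains as many placeholder tokens as visual patches to insert.
        out[b, positions] = visual_embeds[b, : positions.numel()].to(out.dtype)
    return out
```

In the real model, the overridden `forward` would compute `visual_embeds` from the frozen CLIP/SIGLIP encoder before handing the modified embeddings to the LLM's transformer blocks.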
| AI Challenger | COCO | complex_reasoning_77k.json | detail_23k.json |
| --- | --- | --- | --- |
| AI Challenger | COCO 2017 | complex_reasoning_77k.json | detail_23k.json |
Please store the datasets according to the paths in the configuration file; of course, the paths can be customized. Please note that these paths need to be consistent with `data/` for the model to read them.
After downloading the data, use `process_image.py` for preprocessing.
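As a rough illustration of this preprocessing step, the sketch below just converts images to RGB and resizes them to a fixed resolution; the actual `process_image.py` may do more or different work, and the directory names and the 224x224 target size are assumptions.

```python
# Hypothetical image-preprocessing sketch (not the repo's process_image.py).
from pathlib import Path

from PIL import Image


def preprocess_images(src_dir: str, dst_dir: str, size: int = 224) -> None:
    """Convert every .jpg under src_dir to RGB, resize it, and save it to dst_dir."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for path in Path(src_dir).glob("*.jpg"):
        img = Image.open(path).convert("RGB").resize((size, size))
        img.save(dst / path.name)


if __name__ == "__main__":
    preprocess_images("data/raw_images", "data/images")  # assumed directory layout
```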
Use `pip install` to install the dependencies in `requirements.txt`:

```bash
pip install -r requirements.txt
```
Model training freezes the image model, and the LLM is fine-tuned with LoRA to reduce the training cost. The parameters to be trained are the visual feature mapping layer and the LoRA parameters in the LLM. Since the mapping layer is randomly initialized, a larger learning rate is set for it than for the LoRA part to balance the optimization speed of the two groups of parameters; a sketch of this setup is shown below.
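Assuming Hugging Face `peft` and standard PyTorch (the `llm` and `feature_proj` objects below are placeholders), the setup might look roughly like this: LoRA adapters are attached to the frozen LLM, and the optimizer uses two parameter groups so the projection layer gets the larger `--feature_proj_lr`.

```python
# Sketch of the training setup: LoRA on the LLM plus a separately trained
# feature projection layer with its own, higher learning rate. Module names
# and the alpha value are assumptions.
import torch
from peft import LoraConfig, get_peft_model


def build_model_and_optimizer(llm, feature_proj, lora_lr=3e-5, feature_proj_lr=1e-4):
    lora_cfg = LoraConfig(
        r=128,                                   # --lora_rank
        lora_alpha=256,                          # assumed scaling factor
        lora_dropout=0.10,                       # --lora_dropout
        target_modules=["c_attn", "w1", "w2"],   # --target_modules "c_attn|w1|w2"
        task_type="CAUSAL_LM",
    )
    llm = get_peft_model(llm, lora_cfg)          # freezes the base LLM, adds LoRA adapters

    # Two parameter groups: LoRA weights at the base learning rate, and the
    # randomly initialized projection MLP at the larger --feature_proj_lr.
    optimizer = torch.optim.AdamW([
        {"params": [p for p in llm.parameters() if p.requires_grad], "lr": lora_lr},
        {"params": feature_proj.parameters(), "lr": feature_proj_lr},
    ])
    return llm, optimizer
```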
Run `train.sh` in the root directory; the relevant parameters can be configured there for experiments.

```bash
sh train.sh
```
Through the above steps, you can start the training process and train the multimodal model.
The model weights will be saved in `--output_dir`, and this path can also be customized.
```bash
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=25642 train.py \
    --lora_rank 128 \
    --lora_dropout 0.10 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --num_train_epochs 2 \
    --save_steps 1000 \
    --save_total_limit 5 \
    --learning_rate 3e-5 \
    --seed 42 \
    --ddp_find_unused_parameters False \
    --feature_proj_lr 1e-4 \
    --remove_unused_columns false \
    --logging_steps 100 \
    --output_dir ./weights/train_V1_5 \
    --target_modules "c_attn|w1|w2" \
    --image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json \
    --captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json
```
- `CUDA_VISIBLE_DEVICES=0`: Use the GPU with ID 0.
- `torchrun`: PyTorch's distributed training tool.
- `--nproc_per_node=1`: Run 1 process per node.
- `--master_port=25642`: Set the inter-process communication port.
- `train.py`: Main training script.
- `--lora_rank 128`: The rank of the LoRA layers is 128.
- `--lora_dropout 0.10`: The dropout rate of the LoRA layers is 10%.
- `--per_device_train_batch_size 4`: The training batch size per device is 4.
- `--gradient_accumulation_steps 1`: Gradient accumulation steps are 1.
- `--num_train_epochs 2`: Train for 2 epochs.
- `--save_steps 1000`: Save the model every 1000 steps.
- `--save_total_limit 5`: Save up to 5 checkpoints.
- `--learning_rate 3e-5`: Learning rate is 3e-5.
- `--seed 42`: Random seed is 42.
- `--ddp_find_unused_parameters False`: Disable DDP's search for unused parameters.
- `--feature_proj_lr 1e-4`: Learning rate for the feature projection layer is 1e-4.
- `--remove_unused_columns false`: Retain unused dataset columns.
- `--logging_steps 100`: Log every 100 steps.
- `--output_dir ./weights/train_V1_5`: Output directory.
- `--target_modules "c_attn|w1|w2"`: Target modules for LoRA adaptation.
- `--image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json`: Path to the image mapping file.
- `--captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json`: Path to the captions file.
Run `test.sh` in the root directory; the relevant parameters can be configured there for experiments.

```bash
sh test.sh
```
The code will read images from the test image folder for Q&A.
```bash
python test.py --base_language_model Qwen/Qwen-7B-Chat \
    --base_value_model openai/clip-vit-large-patch14 \
    --model_weights ./weights/train_V1_5/checkpoint-10000/ \
    --image_path ./test_img/1.jpg \
    --prompt "Describe the colors appearing in the image<|extra_0|>"
```
If you want to test the model directly, the pre-trained weights provided are as follows:
| SIGLIP_Qwen_epoch19000 | SIGLIP_Qwen_epoch36000 |
| --- | --- |
| Model1 | Model2 |
You can directly download the relevant files and test them.
- `--base_language_model Qwen/Qwen-7B-Chat`: Specify the path to the base language model, here `Qwen/Qwen-7B-Chat`.
- `--base_value_model openai/clip-vit-large-patch14`: Specify the path to the base visual model, here `openai/clip-vit-large-patch14`.
- `--model_weights ./weights/train_V1_5/checkpoint-10000/`: Specify the path to the trained model weights, here the checkpoint `checkpoint-10000` saved during training.
- `--image_path ./test_img/1.jpg`: Specify the path to the input image, here `./test_img/1.jpg`.
- `--prompt "Describe the colors appearing in the image<|extra_0|>"`: Specify the prompt for the model, here asking it to describe the colors appearing in the image; `<|extra_0|>` is the placeholder token where the image features are inserted.
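If you prefer to load a checkpoint in your own script instead of through `test.py`, a minimal sketch is shown below, assuming the LoRA adapter in the checkpoint directory was saved with `peft`; the visual projection weights would still need to be loaded by the repo's own code.

```python
# Hypothetical sketch of reloading the LoRA adapter from a training checkpoint.
# Assumes the checkpoint directory contains a peft adapter config and weights.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = PeftModel.from_pretrained(base, "./weights/train_V1_5/checkpoint-10000/")
model.eval()
```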
Thanks to the great work of the following projects 🙌:
- https://github.com/WatchTower-Liu/VLM-learning/tree/main
- https://github.com/QwenLM/Qwen
- https://github.com/haotian-liu/LLaVA
If you have any questions or ideas, feel free to contact me 😊:
I will reply as soon as I see the email!