-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocessing.py
25 lines (22 loc) · 1.08 KB
/
preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
import os
import av
import torch
from transformers import VivitImageProcessor, VivitModel
from data_handling import frames_convert_and_create_dataset_dictionary
from datasets import Dataset
# Module-level ViViT image processor, loaded once from the pretrained
# "google/vivit-b-16x2-kinetics400" checkpoint and used by process_example.
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
def process_example(example):
    """Run one dataset record through the module-level ViViT image processor.

    Args:
        example: A mapping with a ``'video'`` entry (converted to a NumPy
            array and split along its first axis into a list of frames) and
            a ``'labels'`` entry.

    Returns:
        The object returned by ``image_processor`` (called with
        ``return_tensors='pt'``), with the record's ``'labels'`` value
        attached under the ``'labels'`` key.
    """
    # Iterating the array along axis 0 yields one element per frame.
    frames = list(np.array(example['video']))
    encoded = image_processor(frames, return_tensors='pt')
    encoded['labels'] = example['labels']
    return encoded
# Build the raw video records at import time using the helper from data_handling.
# NOTE(review): "file location" looks like a placeholder path — confirm the
# real data directory before running.
video_dict= frames_convert_and_create_dataset_dictionary("file location")
def create_dataset(video_dictionary, *, test_size=0.1, seed=42):
    """Build a preprocessed, shuffled train/test split from video records.

    The records are loaded into a ``Dataset``, the ``"labels"`` column is
    class-encoded, every record is passed through ``process_example``, the
    raw ``'video'`` column is dropped, and the result is shuffled and split.

    Args:
        video_dictionary: Records to load with ``Dataset.from_list``; each
            record needs a ``'video'`` and a ``'labels'`` entry, since both
            columns are read below.
        test_size: Value forwarded to ``train_test_split`` (default ``0.1``).
        seed: Seed forwarded to ``Dataset.shuffle`` (default ``42``).

    Returns:
        The result of ``Dataset.train_test_split`` on the processed dataset.
    """
    # Use the caller's records; the previous version ignored this argument
    # and always read the module-level ``video_dict`` instead.
    dataset = Dataset.from_list(video_dictionary)
    dataset = dataset.class_encode_column("labels")

    processed = dataset.map(process_example, batched=False)
    processed = processed.remove_columns(['video'])

    shuffled = processed.shuffle(seed=seed)
    # Drop size-1 dimensions from each stored pixel_values entry.
    # NOTE(review): presumably this removes the leading batch dimension added
    # by the image processor — confirm against the processor's output shape.
    shuffled = shuffled.map(
        lambda x: {'pixel_values': torch.tensor(x['pixel_values']).squeeze()}
    )

    # NOTE(review): no seed is passed to the split itself (unchanged from the
    # original), so the train/test partition may differ between runs — confirm
    # whether a reproducible split is wanted.
    return shuffled.train_test_split(test_size=test_size)