forked from Charleswyt/tf_audio_steganalysis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pre_process.py
83 lines (68 loc) · 2.6 KB
/
pre_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on 2018.01.03
Finished on 2018.01.03
@author: Wang Yuntao
"""
import math
import numpy as np
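"""
Pre-processing of the QMDCT matrix:
truncate(matrix, threshold, threshold_left=None, threshold_right=None)   value truncation
down_sampling(matrix, mode, mode_number)                                 down-sampling for a single mode
get_down_sampling(matrix, mode_number)                                   down-sampling for all modes combined
"""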
def truncate(matrix, threshold, threshold_left=None, threshold_right=None):
    """
    Truncate (clip) the values of a matrix in place.

    Values below the lower bound are replaced by the lower bound and values
    above the upper bound are replaced by the upper bound. By default the
    bounds are the symmetric range [-threshold, threshold]; ``threshold_left``
    and ``threshold_right`` override the lower and upper bound respectively
    (each one independently).

    :param matrix: the input matrix (numpy.ndarray, modified in place; a
        nested list is converted to a new array first)
    :param threshold: symmetric threshold, used for any bound not given explicitly
    :param threshold_left: lower bound (minimum), optional
    :param threshold_right: upper bound (maximum), optional
    :return: the truncated matrix
    """
    matrix = np.asarray(matrix)
    # Only evaluate -threshold / threshold when the explicit bound is missing,
    # so threshold may be None when both explicit bounds are supplied.
    lower = -threshold if threshold_left is None else threshold_left
    upper = threshold if threshold_right is None else threshold_right
    # Values below the minimum are raised to it; values above the maximum are
    # lowered to it. Values inside [lower, upper] are left untouched.
    matrix[matrix < lower] = lower
    matrix[matrix > upper] = upper
    return matrix
def down_sampling(matrix, mode, mode_number):
    """
    Down-sample a matrix for a single mode.

    The ``mode_number`` modes are laid out row-major on a ``stride x stride``
    grid, where ``stride = sqrt(mode_number)``. Mode ``k`` sits at row
    ``k // stride`` and column ``k % stride`` of that grid. The result keeps
    every ``stride``-th row and column of ``matrix``, starting from that
    (row, column) offset.

    :param matrix: the input matrix (2-D numpy.ndarray)
    :param mode: the current mode, an integer in [0, mode_number)
    :param mode_number: the total number of modes; must be a positive perfect square
    :return: down-sampled matrix (a strided slice of ``matrix``)
    :raises ValueError: if mode_number is not a positive perfect square
    :raises IndexError: if mode is outside [0, mode_number)
    """
    # Round before truncating so float error in sqrt cannot shrink the stride.
    stride = int(round(math.sqrt(mode_number))) if mode_number > 0 else 0
    if stride < 1 or stride * stride != mode_number:
        raise ValueError("mode_number must be a positive perfect square, got %r" % (mode_number,))
    if not 0 <= mode < mode_number:
        raise IndexError("mode must be in [0, %d), got %r" % (mode_number, mode))
    # Row-major position of `mode` on the stride x stride grid.
    row, col = divmod(mode, stride)
    return matrix[row::stride, col::stride]
def get_down_sampling(matrix, mode_number):
    """
    Down-sample a matrix for every mode and stack the results.

    The matrix is first cropped so that its height and width are multiples of
    ``stride = sqrt(mode_number)``. This makes every mode produce a sub-matrix
    of the same size. The sub-matrix of mode ``i`` is stored in
    ``output[:, :, i]``.

    :param matrix: the input matrix; its first two axes are height and width,
        and any remaining axes must have size 1 (it is reshaped to 2-D)
    :param mode_number: the total number of modes; must be a positive perfect square
    :return: float array of shape (height // stride, width // stride, mode_number)
    :raises ValueError: if mode_number is not a positive perfect square
    """
    height, width = np.shape(matrix)[:2]
    matrix = np.reshape(matrix, [height, width])

    stride = int(round(math.sqrt(mode_number))) if mode_number > 0 else 0
    if stride < 1 or stride * stride != mode_number:
        raise ValueError("mode_number must be a positive perfect square, got %r" % (mode_number,))

    sub_height, sub_width = height // stride, width // stride
    # Without this crop, modes with a smaller row/column offset would get one
    # extra row/column whenever height/width is not divisible by stride, and
    # the assignment into `output` below would fail with a shape mismatch.
    matrix = matrix[:sub_height * stride, :sub_width * stride]

    output = np.zeros([sub_height, sub_width, mode_number])
    for i in range(mode_number):
        output[:, :, i] = down_sampling(matrix, i, mode_number)
    return output
if __name__ == "__main__":
    # Demo: load one QMDCT text file, print its shape, then split it into
    # 4 down-sampled modes and print the stacked result's shape.
    from text_preprocess import read_text

    sample_path = "E:/Myself/2.database/10.QMDCT/1.txt/APS/128_01/wav10s_00689.txt"
    read_options = dict(is_abs=True, is_diff=True, order=2, direction=1, is_trunc=True, threshold=3)
    qmdct_matrix = read_text(sample_path, **read_options)
    print(np.shape(qmdct_matrix))

    stacked_modes = get_down_sampling(qmdct_matrix, 4)
    print(np.shape(stacked_modes))