-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathconfiguration.py
51 lines (44 loc) · 2.86 KB
/
configuration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class generator_config(object):
    """Hyperparameter container for the generator network.

    All values are assigned their defaults at construction time; tweak
    the attributes on an instance before passing it to the model.
    """
    def __init__(self):
        # --- embedding / vocabulary ---
        self.emb_dim = 32          # size of each token embedding vector
        self.num_emb = 5000        # number of output units (vocabulary size)
        # --- recurrent architecture ---
        self.hidden_dim = 32       # size of the recurrent hidden state
        self.sequence_length = 20  # maximum length of an input sequence
        # --- batching / decoding ---
        self.gen_batch_size = 64   # mini-batch size used by the generator
        self.start_token = 0       # special token marking the start of a sentence
class discriminator_config(object):
    """Hyperparameter container for the CNN discriminator.

    Defaults are assigned at construction time. Note that ``vocab_size``
    must match ``generator_config.num_emb``.
    """
    def __init__(self):
        # --- input shape ---
        self.sequence_length = 20  # maximum length of an input sequence
        self.num_classes = 2       # number of target classes (real vs. fake)
        self.vocab_size = 5000     # vocabulary size; should equal num_emb
        # --- convolutional architecture ---
        self.dis_embedding_dim = 64  # dimension of the discriminator embedding space
        # One conv kernel width per entry, paired with its filter count below.
        self.dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
        self.dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
        # --- regularization ---
        self.dis_dropout_keep_prob = 0.75  # dropout keep probability
        self.dis_l2_reg_lambda = 0.2       # L2 regularization strength
        # --- optimization ---
        self.dis_batch_size = 64       # mini-batch size for the discriminator
        self.dis_learning_rate = 1e-4  # discriminator learning rate
class training_config(object):
    """Hyperparameter container for the overall training procedure.

    Covers pretraining, adversarial training, evaluation cadence, and
    the on-disk locations of generated data files.
    """
    def __init__(self):
        # --- generator optimization ---
        self.gen_learning_rate = 0.01  # learning rate of the generator
        self.gen_update_time = 1       # generator updates per adversarial round
        # --- discriminator schedule (adversarial phase) ---
        self.dis_update_time_adv = 5   # discriminator updates per adversarial round
        self.dis_update_epoch_adv = 3  # epochs per discriminator update (adversarial)
        # --- discriminator schedule (pretraining phase) ---
        self.dis_update_time_pre = 50  # discriminator pretraining iterations
        self.dis_update_epoch_pre = 3  # epochs per pretraining iteration
        # --- generator pretraining / rollout ---
        self.pretrained_epoch_num = 120  # number of generator pretraining epochs
        self.rollout_num = 16            # rollout count for reward estimation
        self.test_per_epoch = 5          # evaluate NLL every this many epochs
        # --- general optimization settings ---
        self.batch_size = 64         # mini-batch size used during training
        self.save_pretrained = 120   # epoch at which to save the pretrained model (optional)
        self.grad_clip = 5.0         # gradient clipping threshold
        self.seed = 88               # random seed for initialization
        self.start_token = 0         # special token marking the start of a sentence
        self.total_batch = 200       # total batches of adversarial training
        # --- data file locations ---
        self.positive_file = "save/real_data.txt"         # real data produced by the target LSTM
        self.negative_file = "save/generator_sample.txt"  # fake data produced by the generator
        self.eval_file = "save/eval_file.txt"             # file used for evaluation
        self.generated_num = 10000  # number of generator samples used for evaluation