-
Notifications
You must be signed in to change notification settings - Fork 8
/
train.py
executable file
·63 lines (32 loc) · 984 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import json
from lib.core.base_trainer.net_work import Train
from lib.dataset.dataietr import AlaskaDataIter
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import pandas as pd
from train_config import config as cfg
import setproctitle
# Rename the OS-visible process title to "pks" (e.g. what ps/top display).
# NOTE: this sits at module top level, so it runs on import as well as when
# the file is executed as a script.
setproctitle.setproctitle("pks")
def get_fold(df, n_folds, seed=None):
    """Assign a shuffled K-fold split to ``df`` through a ``'fold'`` column.

    Row positions are shuffled with ``numpy.random.RandomState(seed)`` and cut
    into ``n_folds`` contiguous chunks; the first ``len(df) % n_folds`` chunks
    hold one extra row.  This is the same partition that
    ``sklearn.model_selection.KFold(n_splits=n_folds, shuffle=True,
    random_state=seed)`` produces, computed with numpy only.

    Args:
        df: DataFrame to annotate.  It is modified in place.
        n_folds: Number of folds; must satisfy ``2 <= n_folds <= len(df)``.
        seed: Integer seed for the shuffle.  Defaults to ``cfg.SEED``.

    Returns:
        The same ``df`` object, with an integer ``'fold'`` column whose values
        lie in ``[0, n_folds)``; every row receives exactly one fold id.

    Raises:
        ValueError: If ``n_folds`` is smaller than 2 or larger than ``len(df)``.
    """
    if seed is None:
        seed = cfg.SEED
    n_rows = len(df)
    if not 2 <= n_folds <= n_rows:
        raise ValueError(
            f"n_folds must be between 2 and the number of rows ({n_rows}), "
            f"got {n_folds}"
        )

    order = np.arange(n_rows)
    np.random.RandomState(seed).shuffle(order)

    # Fold sizes as in KFold: n_rows // n_folds each, remainder spread over
    # the leading folds.
    fold_sizes = np.full(n_folds, n_rows // n_folds, dtype=int)
    fold_sizes[: n_rows % n_folds] += 1

    fold_ids = np.empty(n_rows, dtype=int)
    fold_ids[order] = np.repeat(np.arange(n_folds), fold_sizes)

    # Assigning an array is positional, so a non-RangeIndex (e.g. after
    # filtering) is handled correctly, unlike label-based ``df.loc[positions]``.
    df['fold'] = fold_ids
    return df
def setppm100asval(df):
    """Split rows into two folds by whether the image name contains 'PPM'.

    Rows whose ``image`` value contains the substring ``'PPM'`` (case-sensitive)
    are given ``fold == 0``; all other rows are given ``fold == 1``.  ``df`` is
    modified in place and also returned.
    """
    def _fold_for(name):
        # Anything without the 'PPM' marker goes to fold 1.
        return 1 if 'PPM' not in name else 0

    df['fold'] = df['image'].map(_fold_for)
    return df
def main():
    """Load the train/val CSV files named in the config and run training.

    Reads ``cfg.DATA.train_f_path`` and ``cfg.DATA.val_f_path`` with pandas,
    builds a ``Train`` object for fold 0 from them, and calls its
    ``custom_loop()``.
    """
    data_cfg = cfg.DATA
    frames = {
        'train_df': pd.read_csv(data_cfg.train_f_path),
        'val_df': pd.read_csv(data_cfg.val_f_path),
    }
    # Build the trainer, then hand control to its training loop.
    Train(fold=0, **frames).custom_loop()
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()