"""Basic utilities for configuration, file management, caching, and logging."""
import torch.cuda.memory
import yaml
import logging
import os
import sys
import datetime
import pickle
import bz2
from pathlib import Path
import urllib.request
import pandas as pd
# CONFIGURATION

def get_config(path: str):
    """Return the configuration value stored under the dot-separated `path` in 'config.yaml'."""
    try:
        value = __CONFIG__
        for part in path.split('.'):
            value = value[part]
    except KeyError as e:
        raise KeyError(f'Could not find configuration value for path "{path}"!') from e
    return value
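
# Illustrative usage, assuming config.yaml defines a nested `logging` section
# (the concrete keys and values are assumptions, not part of this module):
#   >>> get_config('logging.level')
#   'INFO'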
def set_config(path: str, val):
    """Override the configuration value at the dot-separated `path`."""
    try:
        current = __CONFIG__
        path_segments = path.split('.')
        for s in path_segments[:-1]:
            current = current[s]
        current[path_segments[-1]] = val
    except KeyError as e:
        raise KeyError(f'Could not find configuration value for path "{path}"!') from e


# load the configuration once at import time (from the 'config.yaml' next to this module)
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config.yaml'), mode='r') as config_file:
    __CONFIG__ = yaml.full_load(config_file)
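
# Illustrative usage (the key is an assumption; the actual schema depends on config.yaml):
#   >>> set_config('logging.level', 'DEBUG')
#   >>> get_config('logging.level')
#   'DEBUG'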
# FILES

def get_data_file(config_path: str) -> str:
    """Return the local path of the data file specified at `config_path`, downloading it first if necessary."""
    config = get_config(config_path)
    filepath = os.path.join(_get_root_path(), 'data', config['filename'])
    if not os.path.isfile(filepath):
        os.makedirs(os.path.dirname(filepath), exist_ok=True)  # make sure the data directory exists
        url = config['url']
        get_logger().info(f'Downloading file from {url} ...')
        urllib.request.urlretrieve(url, filepath)
        get_logger().info(f'Finished downloading from {url}')
    return filepath
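
# Sketch of the config entry this expects (names and URL are illustrative, not from the actual config):
#   files:
#     my_dataset:
#       filename: my_dataset.csv
#       url: https://example.org/my_dataset.csv
# `get_data_file('files.my_dataset')` would then return '<root>/data/my_dataset.csv',
# downloading the file from the URL first if it is not present locally.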
def get_results_file(config_path: str) -> str:
    """Return the path of the results file whose filename is specified at `config_path`."""
    return os.path.join(_get_root_path(), 'results', get_config(config_path))
# CACHING

def load_or_create_cache(cache_identifier: str, init_func, version=None):
    """Return the object cached at `cache_identifier` if it exists; otherwise create it with `init_func` and cache it."""
    cache_obj = load_cache(cache_identifier, version=version)
    if cache_obj is None:
        cache_obj = init_func()
        update_cache(cache_identifier, cache_obj, version=version)
    return cache_obj
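
# Illustrative usage, assuming a `cache.embeddings` entry exists in config.yaml;
# `compute_embeddings` stands in for any expensive initialiser that should only run on a cache miss:
#   >>> embeddings = load_or_create_cache('embeddings', compute_embeddings)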
def load_cache(cache_identifier: str, version=None):
    """Return the object cached at `cache_identifier`, or None if no cache file exists."""
    config = get_config(f'cache.{cache_identifier}')
    cache_path = _get_cache_path(cache_identifier, version=version)
    if not cache_path.exists():
        return None
    if _should_store_as_hdf(config):
        return pd.read_hdf(cache_path, key='df')
    elif _should_store_as_csv(config):
        return pd.read_csv(cache_path, sep=';', index_col=0)
    else:
        open_func = _get_cache_open_func(config)
        with open_func(cache_path, mode='rb') as cache_file:
            return pickle.load(cache_file)
def update_cache(cache_identifier: str, cache_obj, version=None):
    """Write `cache_obj` to the cache at `cache_identifier`, overwriting any previous version."""
    config = get_config(f'cache.{cache_identifier}')
    cache_path = _get_cache_path(cache_identifier, version=version)
    cache_path.parent.mkdir(parents=True, exist_ok=True)  # make sure the cache directory exists
    if _should_store_as_hdf(config):
        cache_obj.to_hdf(cache_path, key='df', mode='w')
    elif _should_store_as_csv(config):
        cache_obj.to_csv(cache_path, sep=';')
    else:
        open_func = _get_cache_open_func(config)
        with open_func(cache_path, mode='wb') as cache_file:
            pickle.dump(cache_obj, cache_file, protocol=pickle.HIGHEST_PROTOCOL)
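
# Sketch of a cache entry in config.yaml as read by the functions above (the keys are
# inferred from this module; the concrete values are assumptions):
#   cache:
#     embeddings:
#       filename: embeddings
#       version: 1
#       compress: true       # optional: bz2-compress pickled objects
#       store_as_hdf: false  # optional: store pandas objects as HDF5 instead of pickle
#       store_as_csv: false  # optional: store pandas objects as ';'-separated CSV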
def _get_cache_path(cache_identifier: str, version=None) -> Path:
    config = get_config(f'cache.{cache_identifier}')
    filename = config['filename']
    version = version or config['version']
    if _should_store_as_folder(config):
        base_fileformat = ''
    elif _should_store_as_hdf(config):
        base_fileformat = '.h5'
    elif _should_store_as_csv(config):
        base_fileformat = '.csv'
    else:
        base_fileformat = '.p'
    fileformat = base_fileformat + ('.bz2' if _should_compress_cache(config) else '')
    return Path(_get_root_path()) / 'data' / 'cache' / f'{filename}_v{version}{fileformat}'
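
# e.g. the (assumed) entry sketched above, with compression enabled, would resolve to
# '<root>/data/cache/embeddings_v1.p.bz2'.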
def _get_cache_open_func(config: dict):
    return bz2.open if _should_compress_cache(config) else open


def _should_compress_cache(config: dict) -> bool:
    return bool(config.get('compress', False))


def _should_store_as_folder(config: dict) -> bool:
    return bool(config.get('store_as_folder', False))


def _should_store_as_hdf(config: dict) -> bool:
    return bool(config.get('store_as_hdf', False))


def _should_store_as_csv(config: dict) -> bool:
    return bool(config.get('store_as_csv', False))


def _get_root_path() -> str:
    return os.path.dirname(os.path.realpath(__file__))
# LOGGING

def get_logger():
    """Return the application-wide logger."""
    return logging.getLogger('impl')


# initialise logging once at import time: log to a timestamped file unless running in a notebook
if get_config('logging.to_file') and 'ipykernel' not in sys.modules:
    log_filename = '{}_{}.log'.format(datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), get_config('logging.filename'))
    log_filepath = os.path.join(_get_root_path(), 'logs', log_filename)
    os.makedirs(os.path.dirname(log_filepath), exist_ok=True)  # make sure the log directory exists
    log_handler = logging.FileHandler(log_filepath, 'a', 'utf-8')
else:
    log_handler = logging.StreamHandler()
log_handler.setFormatter(logging.Formatter('%(asctime)s|%(levelname)s|%(module)s->%(funcName)s: %(message)s'))
log_handler.setLevel(get_config('logging.level'))
logger = get_logger()
logger.addHandler(log_handler)
logger.setLevel(get_config('logging.level'))
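
# With the formatter above, an emitted log line looks roughly like this (values illustrative):
#   2024-01-31 12:00:00,000|INFO|utils->get_data_file: Downloading file from https://example.org/x.csv ...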
# GPU CACHE MANAGEMENT

__CACHE_MEMORY_POINTER__ = None


def reserve_gpu(memory_in_gb: int = 47):
    """Pre-allocate `memory_in_gb` GB of GPU memory via torch's caching allocator to reserve it for this process."""
    global __CACHE_MEMORY_POINTER__
    if __CACHE_MEMORY_POINTER__ is not None:
        raise MemoryError('Tried to allocate memory but cache memory pointer is already set!')
    __CACHE_MEMORY_POINTER__ = torch.cuda.memory.caching_allocator_alloc(1024 * 1024 * 1024 * memory_in_gb)


def release_gpu():
    """Release the GPU memory reserved by `reserve_gpu` and clear torch's cache."""
    global __CACHE_MEMORY_POINTER__
    if isinstance(__CACHE_MEMORY_POINTER__, int):
        torch.cuda.memory.caching_allocator_delete(__CACHE_MEMORY_POINTER__)
        torch.cuda.memory.empty_cache()
    __CACHE_MEMORY_POINTER__ = None
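
# Illustrative usage: hold on to GPU memory while idle so that other processes cannot
# claim it, then hand it back right before running your own workload
# (the 40 GB figure is an assumption, not a recommendation):
#   >>> reserve_gpu(memory_in_gb=40)
#   >>> ...
#   >>> release_gpu()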