-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored xin branch #1
base: xin
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ def _download_and_extract(url, path, filename): | |
os.makedirs(path, exist_ok=True) | ||
f_remote = requests.get(url, stream=True) | ||
sz = f_remote.headers.get('content-length') | ||
assert f_remote.status_code == 200, 'fail to open {}'.format(url) | ||
assert f_remote.status_code == 200, f'fail to open {url}' | ||
with open(fn, 'wb') as writer: | ||
for chunk in f_remote.iter_content(chunk_size=1024*1024): | ||
writer.write(chunk) | ||
|
@@ -65,8 +65,7 @@ def _parse_srd_format(format): | |
|
||
def _file_line(path): | ||
with open(path) as f: | ||
for i, l in enumerate(f): | ||
pass | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
return i + 1 | ||
|
||
class KGDataset: | ||
|
@@ -117,7 +116,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]): | |
if path is None: | ||
return None | ||
|
||
print('Reading {} triples....'.format(mode)) | ||
print(f'Reading {mode} triples....') | ||
Comment on lines
-120
to
+119
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
heads = [] | ||
tails = [] | ||
rels = [] | ||
|
@@ -134,7 +133,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]): | |
heads = np.array(heads, dtype=np.int64) | ||
tails = np.array(tails, dtype=np.int64) | ||
rels = np.array(rels, dtype=np.int64) | ||
print('Finished. Read {} {} triples.'.format(len(heads), mode)) | ||
print(f'Finished. Read {len(heads)} {mode} triples.') | ||
|
||
return (heads, rels, tails) | ||
|
||
|
@@ -164,7 +163,7 @@ def read_triple(self, path, mode): | |
heads = [] | ||
tails = [] | ||
rels = [] | ||
print('Reading {} triples....'.format(mode)) | ||
print(f'Reading {mode} triples....') | ||
Comment on lines
-167
to
+166
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
with open(path) as f: | ||
for line in f: | ||
h, r, t = line.strip().split('\t') | ||
|
@@ -175,7 +174,7 @@ def read_triple(self, path, mode): | |
heads = np.array(heads, dtype=np.int64) | ||
tails = np.array(tails, dtype=np.int64) | ||
rels = np.array(rels, dtype=np.int64) | ||
print('Finished. Read {} {} triples.'.format(len(heads), mode)) | ||
print(f'Finished. Read {len(heads)} {mode} triples.') | ||
|
||
return (heads, rels, tails) | ||
|
||
|
@@ -195,11 +194,11 @@ class KGDatasetFB15k(KGDataset): | |
''' | ||
def __init__(self, path, name='FB15k'): | ||
self.name = name | ||
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name) | ||
url = f'https://data.dgl.ai/dataset/{name}.zip' | ||
|
||
if not os.path.exists(os.path.join(path, name)): | ||
print('File not found. Downloading from', url) | ||
_download_and_extract(url, path, name + '.zip') | ||
_download_and_extract(url, path, f'{name}.zip') | ||
Comment on lines
-198
to
+201
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.path = os.path.join(path, name) | ||
|
||
super(KGDatasetFB15k, self).__init__(os.path.join(self.path, 'entities.dict'), | ||
|
@@ -224,11 +223,11 @@ class KGDatasetFB15k237(KGDataset): | |
''' | ||
def __init__(self, path, name='FB15k-237'): | ||
self.name = name | ||
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name) | ||
url = f'https://data.dgl.ai/dataset/{name}.zip' | ||
|
||
if not os.path.exists(os.path.join(path, name)): | ||
print('File not found. Downloading from', url) | ||
_download_and_extract(url, path, name + '.zip') | ||
_download_and_extract(url, path, f'{name}.zip') | ||
Comment on lines
-227
to
+230
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.path = os.path.join(path, name) | ||
|
||
super(KGDatasetFB15k237, self).__init__(os.path.join(self.path, 'entities.dict'), | ||
|
@@ -253,11 +252,11 @@ class KGDatasetWN18(KGDataset): | |
''' | ||
def __init__(self, path, name='wn18'): | ||
self.name = name | ||
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name) | ||
url = f'https://data.dgl.ai/dataset/{name}.zip' | ||
|
||
if not os.path.exists(os.path.join(path, name)): | ||
print('File not found. Downloading from', url) | ||
_download_and_extract(url, path, name + '.zip') | ||
_download_and_extract(url, path, f'{name}.zip') | ||
Comment on lines
-256
to
+259
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.path = os.path.join(path, name) | ||
|
||
super(KGDatasetWN18, self).__init__(os.path.join(self.path, 'entities.dict'), | ||
|
@@ -282,11 +281,11 @@ class KGDatasetWN18rr(KGDataset): | |
''' | ||
def __init__(self, path, name='wn18rr'): | ||
self.name = name | ||
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name) | ||
url = f'https://data.dgl.ai/dataset/{name}.zip' | ||
|
||
if not os.path.exists(os.path.join(path, name)): | ||
print('File not found. Downloading from', url) | ||
_download_and_extract(url, path, name + '.zip') | ||
_download_and_extract(url, path, f'{name}.zip') | ||
Comment on lines
-285
to
+288
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.path = os.path.join(path, name) | ||
|
||
super(KGDatasetWN18rr, self).__init__(os.path.join(self.path, 'entities.dict'), | ||
|
@@ -310,11 +309,11 @@ class KGDatasetFreebase(KGDataset): | |
''' | ||
def __init__(self, path, name='Freebase'): | ||
self.name = name | ||
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name) | ||
url = f'https://data.dgl.ai/dataset/{name}.zip' | ||
|
||
if not os.path.exists(os.path.join(path, name)): | ||
print('File not found. Downloading from', url) | ||
_download_and_extract(url, path, '{}.zip'.format(name)) | ||
_download_and_extract(url, path, f'{name}.zip') | ||
Comment on lines
-313
to
+316
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.path = os.path.join(path, name) | ||
|
||
super(KGDatasetFreebase, self).__init__(os.path.join(self.path, 'entity2id.txt'), | ||
|
@@ -337,7 +336,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=None): | |
heads = [] | ||
tails = [] | ||
rels = [] | ||
print('Reading {} triples....'.format(mode)) | ||
print(f'Reading {mode} triples....') | ||
Comment on lines
-340
to
+339
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
with open(path) as f: | ||
if skip_first_line: | ||
_ = f.readline() | ||
|
@@ -350,7 +349,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=None): | |
heads = np.array(heads, dtype=np.int64) | ||
tails = np.array(tails, dtype=np.int64) | ||
rels = np.array(rels, dtype=np.int64) | ||
print('Finished. Read {} {} triples.'.format(len(heads), mode)) | ||
print(f'Finished. Read {len(heads)} {mode} triples.') | ||
return (heads, rels, tails) | ||
|
||
class KGDatasetUDDRaw(KGDataset): | ||
|
@@ -369,8 +368,7 @@ class KGDatasetUDDRaw(KGDataset): | |
def __init__(self, path, name, files, format): | ||
self.name = name | ||
for f in files: | ||
assert os.path.exists(os.path.join(path, f)), \ | ||
'File {} now exist in {}'.format(f, path) | ||
assert os.path.exists(os.path.join(path, f)), f'File {f} now exist in {path}' | ||
Comment on lines
-372
to
+371
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
assert len(format) == 3 | ||
format = _parse_srd_format(format) | ||
|
@@ -437,8 +435,7 @@ class KGDatasetUDD(KGDataset): | |
def __init__(self, path, name, files, format): | ||
self.name = name | ||
for f in files: | ||
assert os.path.exists(os.path.join(path, f)), \ | ||
'File {} now exist in {}'.format(f, path) | ||
assert os.path.exists(os.path.join(path, f)), f'File {f} now exist in {path}' | ||
Comment on lines
-440
to
+438
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
format = _parse_srd_format(format) | ||
if len(files) == 3: | ||
|
@@ -458,22 +455,22 @@ def __init__(self, path, name, files, format): | |
def read_entity(self, entity_path): | ||
n_entities = 0 | ||
with open(entity_path) as f_ent: | ||
for line in f_ent: | ||
for _ in f_ent: | ||
Comment on lines
-461
to
+458
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
n_entities += 1 | ||
return None, n_entities | ||
|
||
def read_relation(self, relation_path): | ||
n_relations = 0 | ||
with open(relation_path) as f_rel: | ||
for line in f_rel: | ||
for _ in f_rel: | ||
Comment on lines
-468
to
+465
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
n_relations += 1 | ||
return None, n_relations | ||
|
||
def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]): | ||
heads = [] | ||
tails = [] | ||
rels = [] | ||
print('Reading {} triples....'.format(mode)) | ||
print(f'Reading {mode} triples....') | ||
Comment on lines
-476
to
+473
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
with open(path) as f: | ||
if skip_first_line: | ||
_ = f.readline() | ||
|
@@ -486,7 +483,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]): | |
heads = np.array(heads, dtype=np.int64) | ||
tails = np.array(tails, dtype=np.int64) | ||
rels = np.array(rels, dtype=np.int64) | ||
print('Finished. Read {} {} triples.'.format(len(heads), mode)) | ||
print(f'Finished. Read {len(heads)} {mode} triples.') | ||
return (heads, rels, tails) | ||
|
||
def get_dataset(data_path, data_name, format_str, files=None): | ||
|
@@ -502,7 +499,7 @@ def get_dataset(data_path, data_name, format_str, files=None): | |
elif data_name == 'wn18rr': | ||
dataset = KGDatasetWN18rr(data_path) | ||
else: | ||
assert False, "Unknown dataset {}".format(data_name) | ||
assert False, f"Unknown dataset {data_name}" | ||
Comment on lines
-505
to
+502
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
elif format_str.startswith('raw_udd'): | ||
# user defined dataset | ||
format = format_str[8:] | ||
|
@@ -512,13 +509,13 @@ def get_dataset(data_path, data_name, format_str, files=None): | |
format = format_str[4:] | ||
dataset = KGDatasetUDD(data_path, data_name, files, format) | ||
else: | ||
assert False, "Unknown format {}".format(format_str) | ||
assert False, f"Unknown format {format_str}" | ||
|
||
return dataset | ||
|
||
|
||
def get_partition_dataset(data_path, data_name, part_id): | ||
part_name = os.path.join(data_name, 'partition_'+str(part_id)) | ||
part_name = os.path.join(data_name, f'partition_{str(part_id)}') | ||
Comment on lines
-521
to
+518
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
path = os.path.join(data_path, part_name) | ||
|
||
if not os.path.exists(path): | ||
|
@@ -547,19 +544,15 @@ def get_partition_dataset(data_path, data_name, part_id): | |
|
||
partition_book = [] | ||
with open(partition_book_path) as f: | ||
for line in f: | ||
partition_book.append(int(line)) | ||
|
||
partition_book.extend(int(line) for line in f) | ||
local_to_global = [] | ||
with open(local2global_path) as f: | ||
for line in f: | ||
local_to_global.append(int(line)) | ||
|
||
local_to_global.extend(int(line) for line in f) | ||
return dataset, partition_book, local_to_global | ||
|
||
|
||
def get_server_partition_dataset(data_path, data_name, part_id): | ||
part_name = os.path.join(data_name, 'partition_'+str(part_id)) | ||
part_name = os.path.join(data_name, f'partition_{str(part_id)}') | ||
Comment on lines
-562
to
+555
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
path = os.path.join(data_path, part_name) | ||
|
||
if not os.path.exists(path): | ||
|
@@ -589,9 +582,7 @@ def get_server_partition_dataset(data_path, data_name, part_id): | |
|
||
local_to_global = [] | ||
with open(local2global_path) as f: | ||
for line in f: | ||
local_to_global.append(int(line)) | ||
|
||
local_to_global.extend(int(line) for line in f) | ||
global_to_local = [0] * n_entities | ||
for i in range(len(local_to_global)): | ||
global_id = local_to_global[i] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,7 +64,7 @@ def SoftRelationPartition(edges, n, threshold=0.05): | |
Whether there exists some relations belongs to multiple partitions | ||
""" | ||
heads, rels, tails = edges | ||
print('relation partition {} edges into {} parts'.format(len(heads), n)) | ||
print(f'relation partition {len(heads)} edges into {n} parts') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
uniq, cnts = np.unique(rels, return_counts=True) | ||
idx = np.flip(np.argsort(cnts)) | ||
cnts = cnts[idx] | ||
|
@@ -73,16 +73,12 @@ def SoftRelationPartition(edges, n, threshold=0.05): | |
edge_cnts = np.zeros(shape=(n,), dtype=np.int64) | ||
rel_cnts = np.zeros(shape=(n,), dtype=np.int64) | ||
rel_dict = {} | ||
rel_parts = [] | ||
cross_rel_part = [] | ||
for _ in range(n): | ||
rel_parts.append([]) | ||
|
||
rel_parts = [[] for _ in range(n)] | ||
large_threshold = int(len(rels) * threshold) | ||
capacity_per_partition = int(len(rels) / n) | ||
# ensure any relation larger than the partition capacity will be split | ||
large_threshold = capacity_per_partition if capacity_per_partition < large_threshold \ | ||
else large_threshold | ||
large_threshold = min(capacity_per_partition, large_threshold) | ||
num_cross_part = 0 | ||
for i in range(len(cnts)): | ||
cnt = cnts[i] | ||
|
@@ -108,8 +104,8 @@ def SoftRelationPartition(edges, n, threshold=0.05): | |
rel_dict[r] = r_parts | ||
|
||
for i, edge_cnt in enumerate(edge_cnts): | ||
print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i])) | ||
print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts))) | ||
print(f'part {i} has {edge_cnt} edges and {rel_cnts[i]} relations') | ||
print(f'{num_cross_part}/{len(cnts)} duplicated relation across partitions') | ||
|
||
parts = [] | ||
for i in range(n): | ||
|
@@ -171,7 +167,7 @@ def BalancedRelationPartition(edges, n): | |
Whether there exists some relations belongs to multiple partitions | ||
""" | ||
heads, rels, tails = edges | ||
print('relation partition {} edges into {} parts'.format(len(heads), n)) | ||
print(f'relation partition {len(heads)} edges into {n} parts') | ||
Comment on lines
-174
to
+170
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
uniq, cnts = np.unique(rels, return_counts=True) | ||
idx = np.flip(np.argsort(cnts)) | ||
cnts = cnts[idx] | ||
|
@@ -180,10 +176,7 @@ def BalancedRelationPartition(edges, n): | |
edge_cnts = np.zeros(shape=(n,), dtype=np.int64) | ||
rel_cnts = np.zeros(shape=(n,), dtype=np.int64) | ||
rel_dict = {} | ||
rel_parts = [] | ||
for _ in range(n): | ||
rel_parts.append([]) | ||
|
||
rel_parts = [[] for _ in range(n)] | ||
max_edges = (len(rels) // n) + 1 | ||
num_cross_part = 0 | ||
for i in range(len(cnts)): | ||
|
@@ -210,8 +203,8 @@ def BalancedRelationPartition(edges, n): | |
rel_dict[r] = r_parts | ||
|
||
for i, edge_cnt in enumerate(edge_cnts): | ||
print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i])) | ||
print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts))) | ||
print(f'part {i} has {edge_cnt} edges and {rel_cnts[i]} relations') | ||
print(f'{num_cross_part}/{len(cnts)} duplicated relation across partitions') | ||
|
||
parts = [] | ||
for i in range(n): | ||
|
@@ -259,7 +252,7 @@ def RandomPartition(edges, n): | |
Edges of each partition | ||
""" | ||
heads, rels, tails = edges | ||
print('random partition {} edges into {} parts'.format(len(heads), n)) | ||
print(f'random partition {len(heads)} edges into {n} parts') | ||
Comment on lines
-262
to
+255
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
idx = np.random.permutation(len(heads)) | ||
heads[:] = heads[idx] | ||
rels[:] = rels[idx] | ||
|
@@ -271,7 +264,7 @@ def RandomPartition(edges, n): | |
start = part_size * i | ||
end = min(part_size * (i + 1), len(idx)) | ||
parts.append(idx[start:end]) | ||
print('part {} has {} edges'.format(i, len(parts[-1]))) | ||
print(f'part {i} has {len(parts[-1])} edges') | ||
return parts | ||
|
||
def ConstructGraph(edges, n_entities, args): | ||
|
@@ -624,7 +617,7 @@ def get_edges(self, eval_type): | |
elif eval_type == 'test': | ||
return self.test | ||
else: | ||
raise Exception('get invalid type: ' + eval_type) | ||
raise Exception(f'get invalid type: {eval_type}') | ||
Comment on lines
-627
to
+620
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def create_sampler(self, eval_type, batch_size, neg_sample_size, neg_chunk_size, | ||
filter_false_neg, mode='head', num_workers=32, rank=0, ranks=1): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function `_download_and_extract` refactored with the following changes: (use-fstring-for-formatting)