Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Sourcery refactored xin branch #1

Open
wants to merge 1 commit into
base: xin
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 31 additions & 40 deletions apps/kg/dataloader/KGDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _download_and_extract(url, path, filename):
os.makedirs(path, exist_ok=True)
f_remote = requests.get(url, stream=True)
sz = f_remote.headers.get('content-length')
assert f_remote.status_code == 200, 'fail to open {}'.format(url)
assert f_remote.status_code == 200, f'fail to open {url}'
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function _download_and_extract refactored with the following changes:

with open(fn, 'wb') as writer:
for chunk in f_remote.iter_content(chunk_size=1024*1024):
writer.write(chunk)
Expand Down Expand Up @@ -65,8 +65,7 @@ def _parse_srd_format(format):

def _file_line(path):
    """Return the number of lines in the text file at *path*.

    Fixes the refactor shown in this hunk, which removed the
    ``for i, l in enumerate(f)`` loop while keeping ``return i + 1``,
    leaving ``i`` undefined (NameError on every call). The original
    enumerate-based version also raised NameError for an empty file;
    this version returns 0 in that case.
    """
    with open(path) as f:
        # A generator expression counts lines without materializing them.
        return sum(1 for _ in f)

class KGDataset:
Expand Down Expand Up @@ -117,7 +116,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
if path is None:
return None

print('Reading {} triples....'.format(mode))
print(f'Reading {mode} triples....')
Comment on lines -120 to +119
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDataset.read_triple refactored with the following changes:

heads = []
tails = []
rels = []
Expand All @@ -134,7 +133,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
heads = np.array(heads, dtype=np.int64)
tails = np.array(tails, dtype=np.int64)
rels = np.array(rels, dtype=np.int64)
print('Finished. Read {} {} triples.'.format(len(heads), mode))
print(f'Finished. Read {len(heads)} {mode} triples.')

return (heads, rels, tails)

Expand Down Expand Up @@ -164,7 +163,7 @@ def read_triple(self, path, mode):
heads = []
tails = []
rels = []
print('Reading {} triples....'.format(mode))
print(f'Reading {mode} triples....')
Comment on lines -167 to +166
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function PartitionKGDataset.read_triple refactored with the following changes:

with open(path) as f:
for line in f:
h, r, t = line.strip().split('\t')
Expand All @@ -175,7 +174,7 @@ def read_triple(self, path, mode):
heads = np.array(heads, dtype=np.int64)
tails = np.array(tails, dtype=np.int64)
rels = np.array(rels, dtype=np.int64)
print('Finished. Read {} {} triples.'.format(len(heads), mode))
print(f'Finished. Read {len(heads)} {mode} triples.')

return (heads, rels, tails)

Expand All @@ -195,11 +194,11 @@ class KGDatasetFB15k(KGDataset):
'''
def __init__(self, path, name='FB15k'):
self.name = name
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
url = f'https://data.dgl.ai/dataset/{name}.zip'

if not os.path.exists(os.path.join(path, name)):
print('File not found. Downloading from', url)
_download_and_extract(url, path, name + '.zip')
_download_and_extract(url, path, f'{name}.zip')
Comment on lines -198 to +201
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetFB15k.__init__ refactored with the following changes:

self.path = os.path.join(path, name)

super(KGDatasetFB15k, self).__init__(os.path.join(self.path, 'entities.dict'),
Expand All @@ -224,11 +223,11 @@ class KGDatasetFB15k237(KGDataset):
'''
def __init__(self, path, name='FB15k-237'):
self.name = name
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
url = f'https://data.dgl.ai/dataset/{name}.zip'

if not os.path.exists(os.path.join(path, name)):
print('File not found. Downloading from', url)
_download_and_extract(url, path, name + '.zip')
_download_and_extract(url, path, f'{name}.zip')
Comment on lines -227 to +230
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetFB15k237.__init__ refactored with the following changes:

self.path = os.path.join(path, name)

super(KGDatasetFB15k237, self).__init__(os.path.join(self.path, 'entities.dict'),
Expand All @@ -253,11 +252,11 @@ class KGDatasetWN18(KGDataset):
'''
def __init__(self, path, name='wn18'):
self.name = name
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
url = f'https://data.dgl.ai/dataset/{name}.zip'

if not os.path.exists(os.path.join(path, name)):
print('File not found. Downloading from', url)
_download_and_extract(url, path, name + '.zip')
_download_and_extract(url, path, f'{name}.zip')
Comment on lines -256 to +259
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetWN18.__init__ refactored with the following changes:

self.path = os.path.join(path, name)

super(KGDatasetWN18, self).__init__(os.path.join(self.path, 'entities.dict'),
Expand All @@ -282,11 +281,11 @@ class KGDatasetWN18rr(KGDataset):
'''
def __init__(self, path, name='wn18rr'):
self.name = name
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
url = f'https://data.dgl.ai/dataset/{name}.zip'

if not os.path.exists(os.path.join(path, name)):
print('File not found. Downloading from', url)
_download_and_extract(url, path, name + '.zip')
_download_and_extract(url, path, f'{name}.zip')
Comment on lines -285 to +288
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetWN18rr.__init__ refactored with the following changes:

self.path = os.path.join(path, name)

super(KGDatasetWN18rr, self).__init__(os.path.join(self.path, 'entities.dict'),
Expand All @@ -310,11 +309,11 @@ class KGDatasetFreebase(KGDataset):
'''
def __init__(self, path, name='Freebase'):
self.name = name
url = 'https://data.dgl.ai/dataset/{}.zip'.format(name)
url = f'https://data.dgl.ai/dataset/{name}.zip'

if not os.path.exists(os.path.join(path, name)):
print('File not found. Downloading from', url)
_download_and_extract(url, path, '{}.zip'.format(name))
_download_and_extract(url, path, f'{name}.zip')
Comment on lines -313 to +316
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetFreebase.__init__ refactored with the following changes:

self.path = os.path.join(path, name)

super(KGDatasetFreebase, self).__init__(os.path.join(self.path, 'entity2id.txt'),
Expand All @@ -337,7 +336,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=None):
heads = []
tails = []
rels = []
print('Reading {} triples....'.format(mode))
print(f'Reading {mode} triples....')
Comment on lines -340 to +339
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetFreebase.read_triple refactored with the following changes:

with open(path) as f:
if skip_first_line:
_ = f.readline()
Expand All @@ -350,7 +349,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=None):
heads = np.array(heads, dtype=np.int64)
tails = np.array(tails, dtype=np.int64)
rels = np.array(rels, dtype=np.int64)
print('Finished. Read {} {} triples.'.format(len(heads), mode))
print(f'Finished. Read {len(heads)} {mode} triples.')
return (heads, rels, tails)

class KGDatasetUDDRaw(KGDataset):
Expand All @@ -369,8 +368,7 @@ class KGDatasetUDDRaw(KGDataset):
def __init__(self, path, name, files, format):
self.name = name
for f in files:
assert os.path.exists(os.path.join(path, f)), \
'File {} now exist in {}'.format(f, path)
assert os.path.exists(os.path.join(path, f)), f'File {f} now exist in {path}'
Comment on lines -372 to +371
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetUDDRaw.__init__ refactored with the following changes:


assert len(format) == 3
format = _parse_srd_format(format)
Expand Down Expand Up @@ -437,8 +435,7 @@ class KGDatasetUDD(KGDataset):
def __init__(self, path, name, files, format):
self.name = name
for f in files:
assert os.path.exists(os.path.join(path, f)), \
'File {} now exist in {}'.format(f, path)
assert os.path.exists(os.path.join(path, f)), f'File {f} now exist in {path}'
Comment on lines -440 to +438
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetUDD.__init__ refactored with the following changes:


format = _parse_srd_format(format)
if len(files) == 3:
Expand All @@ -458,22 +455,22 @@ def __init__(self, path, name, files, format):
def read_entity(self, entity_path):
    """Count the entities listed one per line in *entity_path*.

    Returns a ``(None, n_entities)`` pair; the first slot is a
    placeholder kept so the return shape matches the other dataset
    readers, which return an id-mapping in that position.
    """
    with open(entity_path) as ent_file:
        entity_count = sum(1 for _ in ent_file)
    return None, entity_count

def read_relation(self, relation_path):
    """Count the relations listed one per line in *relation_path*.

    Returns a ``(None, n_relations)`` pair; the first slot is a
    placeholder kept so the return shape matches the other dataset
    readers, which return an id-mapping in that position.
    """
    with open(relation_path) as rel_file:
        relation_count = sum(1 for _ in rel_file)
    return None, relation_count

def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
heads = []
tails = []
rels = []
print('Reading {} triples....'.format(mode))
print(f'Reading {mode} triples....')
Comment on lines -476 to +473
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function KGDatasetUDD.read_triple refactored with the following changes:

with open(path) as f:
if skip_first_line:
_ = f.readline()
Expand All @@ -486,7 +483,7 @@ def read_triple(self, path, mode, skip_first_line=False, format=[0,1,2]):
heads = np.array(heads, dtype=np.int64)
tails = np.array(tails, dtype=np.int64)
rels = np.array(rels, dtype=np.int64)
print('Finished. Read {} {} triples.'.format(len(heads), mode))
print(f'Finished. Read {len(heads)} {mode} triples.')
return (heads, rels, tails)

def get_dataset(data_path, data_name, format_str, files=None):
Expand All @@ -502,7 +499,7 @@ def get_dataset(data_path, data_name, format_str, files=None):
elif data_name == 'wn18rr':
dataset = KGDatasetWN18rr(data_path)
else:
assert False, "Unknown dataset {}".format(data_name)
assert False, f"Unknown dataset {data_name}"
Comment on lines -505 to +502
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_dataset refactored with the following changes:

elif format_str.startswith('raw_udd'):
# user defined dataset
format = format_str[8:]
Expand All @@ -512,13 +509,13 @@ def get_dataset(data_path, data_name, format_str, files=None):
format = format_str[4:]
dataset = KGDatasetUDD(data_path, data_name, files, format)
else:
assert False, "Unknown format {}".format(format_str)
assert False, f"Unknown format {format_str}"

return dataset


def get_partition_dataset(data_path, data_name, part_id):
part_name = os.path.join(data_name, 'partition_'+str(part_id))
part_name = os.path.join(data_name, f'partition_{str(part_id)}')
Comment on lines -521 to +518
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_partition_dataset refactored with the following changes:

path = os.path.join(data_path, part_name)

if not os.path.exists(path):
Expand Down Expand Up @@ -547,19 +544,15 @@ def get_partition_dataset(data_path, data_name, part_id):

partition_book = []
with open(partition_book_path) as f:
for line in f:
partition_book.append(int(line))

partition_book.extend(int(line) for line in f)
local_to_global = []
with open(local2global_path) as f:
for line in f:
local_to_global.append(int(line))

local_to_global.extend(int(line) for line in f)
return dataset, partition_book, local_to_global


def get_server_partition_dataset(data_path, data_name, part_id):
part_name = os.path.join(data_name, 'partition_'+str(part_id))
part_name = os.path.join(data_name, f'partition_{str(part_id)}')
Comment on lines -562 to +555
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_server_partition_dataset refactored with the following changes:

path = os.path.join(data_path, part_name)

if not os.path.exists(path):
Expand Down Expand Up @@ -589,9 +582,7 @@ def get_server_partition_dataset(data_path, data_name, part_id):

local_to_global = []
with open(local2global_path) as f:
for line in f:
local_to_global.append(int(line))

local_to_global.extend(int(line) for line in f)
global_to_local = [0] * n_entities
for i in range(len(local_to_global)):
global_id = local_to_global[i]
Expand Down
31 changes: 12 additions & 19 deletions apps/kg/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def SoftRelationPartition(edges, n, threshold=0.05):
Whether there exists some relations belongs to multiple partitions
"""
heads, rels, tails = edges
print('relation partition {} edges into {} parts'.format(len(heads), n))
print(f'relation partition {len(heads)} edges into {n} parts')
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SoftRelationPartition refactored with the following changes:

uniq, cnts = np.unique(rels, return_counts=True)
idx = np.flip(np.argsort(cnts))
cnts = cnts[idx]
Expand All @@ -73,16 +73,12 @@ def SoftRelationPartition(edges, n, threshold=0.05):
edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
rel_dict = {}
rel_parts = []
cross_rel_part = []
for _ in range(n):
rel_parts.append([])

rel_parts = [[] for _ in range(n)]
large_threshold = int(len(rels) * threshold)
capacity_per_partition = int(len(rels) / n)
# ensure any relation larger than the partition capacity will be split
large_threshold = capacity_per_partition if capacity_per_partition < large_threshold \
else large_threshold
large_threshold = min(capacity_per_partition, large_threshold)
num_cross_part = 0
for i in range(len(cnts)):
cnt = cnts[i]
Expand All @@ -108,8 +104,8 @@ def SoftRelationPartition(edges, n, threshold=0.05):
rel_dict[r] = r_parts

for i, edge_cnt in enumerate(edge_cnts):
print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts)))
print(f'part {i} has {edge_cnt} edges and {rel_cnts[i]} relations')
print(f'{num_cross_part}/{len(cnts)} duplicated relation across partitions')

parts = []
for i in range(n):
Expand Down Expand Up @@ -171,7 +167,7 @@ def BalancedRelationPartition(edges, n):
Whether there exists some relations belongs to multiple partitions
"""
heads, rels, tails = edges
print('relation partition {} edges into {} parts'.format(len(heads), n))
print(f'relation partition {len(heads)} edges into {n} parts')
Comment on lines -174 to +170
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function BalancedRelationPartition refactored with the following changes:

uniq, cnts = np.unique(rels, return_counts=True)
idx = np.flip(np.argsort(cnts))
cnts = cnts[idx]
Expand All @@ -180,10 +176,7 @@ def BalancedRelationPartition(edges, n):
edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
rel_dict = {}
rel_parts = []
for _ in range(n):
rel_parts.append([])

rel_parts = [[] for _ in range(n)]
max_edges = (len(rels) // n) + 1
num_cross_part = 0
for i in range(len(cnts)):
Expand All @@ -210,8 +203,8 @@ def BalancedRelationPartition(edges, n):
rel_dict[r] = r_parts

for i, edge_cnt in enumerate(edge_cnts):
print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts)))
print(f'part {i} has {edge_cnt} edges and {rel_cnts[i]} relations')
print(f'{num_cross_part}/{len(cnts)} duplicated relation across partitions')

parts = []
for i in range(n):
Expand Down Expand Up @@ -259,7 +252,7 @@ def RandomPartition(edges, n):
Edges of each partition
"""
heads, rels, tails = edges
print('random partition {} edges into {} parts'.format(len(heads), n))
print(f'random partition {len(heads)} edges into {n} parts')
Comment on lines -262 to +255
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function RandomPartition refactored with the following changes:

idx = np.random.permutation(len(heads))
heads[:] = heads[idx]
rels[:] = rels[idx]
Expand All @@ -271,7 +264,7 @@ def RandomPartition(edges, n):
start = part_size * i
end = min(part_size * (i + 1), len(idx))
parts.append(idx[start:end])
print('part {} has {} edges'.format(i, len(parts[-1])))
print(f'part {i} has {len(parts[-1])} edges')
return parts

def ConstructGraph(edges, n_entities, args):
Expand Down Expand Up @@ -624,7 +617,7 @@ def get_edges(self, eval_type):
elif eval_type == 'test':
return self.test
else:
raise Exception('get invalid type: ' + eval_type)
raise Exception(f'get invalid type: {eval_type}')
Comment on lines -627 to +620
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function EvalDataset.get_edges refactored with the following changes:


def create_sampler(self, eval_type, batch_size, neg_sample_size, neg_chunk_size,
filter_false_neg, mode='head', num_workers=32, rank=0, ranks=1):
Expand Down
Loading