Skip to content

Commit

Permalink
Merge pull request #90 from togethercomputer/local-data
Browse files Browse the repository at this point in the history
Local data
  • Loading branch information
mauriceweber authored Dec 1, 2023
2 parents 26c5417 + 04270d9 commit 8cddd10
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 12 deletions.
21 changes: 14 additions & 7 deletions app/src/artifacts/downloaders/ccnet_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from tqdm import tqdm
import multiprocessing as mp
from multiprocessing.pool import Pool
from urllib.parse import urlparse
import os

from utilities.io import Reader, Writer
from utilities.io.s3 import init_client


class CCNetDownloader(object):
r""" TODO: docstring """
r"""
This class downloads / loads ccnet data and writes it to a jsonl file.
"""

dataset_name = "ccnet"

Expand Down Expand Up @@ -148,12 +151,16 @@ def _load_snapshot(
"tail": int(num_samples * 0.7)
}

s3_client = init_client(
endpoint_url=self._endpoint_url,
signature_version="s3v4",
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")
)
if urlparse(self._cc_input_base_uri).scheme == "s3":
s3_client = init_client(
endpoint_url=self._endpoint_url,
signature_version="s3v4",
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")
)
else:
s3_client = None

reader = Reader(
schema=[("raw_content", str), ("language", str)],
s3_client=s3_client
Expand Down
30 changes: 28 additions & 2 deletions app/src/bloomfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __parse_listings(self):
return uris

@staticmethod
def _load_file(uri, client) -> Tuple[ReadStatus, io.BytesIO]:
def __load_from_s3(uri, client):
try:
streaming_body = client.get_object(
Bucket=uri.netloc, Key=uri.path.lstrip("/")
Expand All @@ -193,6 +193,33 @@ def _load_file(uri, client) -> Tuple[ReadStatus, io.BytesIO]:
buffer = None
is_success = False

return is_success, msg, buffer

@staticmethod
def __load_from_disk(uri):
    """Load a local file into an in-memory buffer.

    Counterpart of ``__load_from_s3`` for ``file://`` URIs: reads the
    file at ``uri.path`` from local disk in binary mode.

    Args:
        uri: a parsed URI (``urllib.parse.ParseResult``-like) whose
            ``path`` attribute points at a file on local disk.

    Returns:
        A ``(is_success, msg, buffer)`` tuple where ``buffer`` is an
        ``io.BytesIO`` with the file contents on success, or ``None``
        on failure; ``msg`` carries a tagged status string either way.
    """
    path = uri.path
    try:
        # read the whole file up front; downstream consumers expect a
        # seekable in-memory stream rather than an open file handle
        with open(path, "rb") as fh:
            payload = fh.read()
    except Exception as exc:
        # best-effort read: report the failure in the message instead
        # of raising, matching the s3 loader's error convention
        failure_msg = (
            f"__DISK_URI_READ_ERROR__ failed reading {path}: "
            f"caught exception {exc.__class__.__name__}: {exc}"
        )
        return False, failure_msg, None

    success_msg = f"__DISK_URI_READ_SUCCESS__ success reading {path}"
    return True, success_msg, io.BytesIO(payload)

def _load_file(self, uri, client) -> Tuple[ReadStatus, io.BytesIO]:
if uri.scheme == "s3":
is_success, msg, buffer = self.__load_from_s3(uri, client)
elif uri.scheme == "file":
is_success, msg, buffer = self.__load_from_disk(uri)
else:
raise ValueError(f"Unknown scheme {uri.scheme}")

read_status = ReadStatus(
is_success=is_success, msg=msg, uri=uri.geturl()
)
Expand Down Expand Up @@ -394,7 +421,6 @@ def __parallel_run(self, input_uris):
def run(self):
    """Entry point: parse the input listings and process them in parallel,
    printing wall-clock start/end timestamps to stdout.
    """
    start_time = dt.now()
    print(f"start @ {start_time.strftime('%Y-%m-%d %H:%M:%S')}")
    # self.__threaded_run(input_uris=self.__parse_listings())
    # NOTE(review): parallel (multiprocessing) path is used; the threaded
    # variant above is kept commented out — presumably superseded, confirm
    # before deleting.
    self.__parallel_run(input_uris=self.__parse_listings())
    end_time = dt.now()
    print(f"end @ {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
Expand Down
9 changes: 6 additions & 3 deletions app/src/prep_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def nullable_string(val):
help="Number of samples to use for classifiers"
)
parser.add_argument(
"--endpoint_url", type=nullable_string,
default=None,
"--endpoint_url", type=nullable_string, default=None,
help="endpoint url where the s3 bucket is exposed."
)

Expand All @@ -88,7 +87,11 @@ def main(artifacts_dir: str, cc_input: str, cc_input_base_uri: str,
classifiers_num_samples: int, max_samples_per_book: int,
max_paragraphs_per_book_sample: int
):
max_workers = min(max_workers, os.cpu_count() - 2)
if max_workers is None:
max_workers = os.cpu_count() - 2
else:
max_workers = min(max_workers, os.cpu_count() - 2)

# parse config
num_samples = max(dsir_num_samples, classifiers_num_samples)

Expand Down

0 comments on commit 8cddd10

Please sign in to comment.