diff --git a/app/src/artifacts/downloaders/ccnet_downloader.py b/app/src/artifacts/downloaders/ccnet_downloader.py
index 4154bac..bcfd1ef 100644
--- a/app/src/artifacts/downloaders/ccnet_downloader.py
+++ b/app/src/artifacts/downloaders/ccnet_downloader.py
@@ -6,6 +6,7 @@
 from tqdm import tqdm
 import multiprocessing as mp
 from multiprocessing.pool import Pool
+from urllib.parse import urlparse
 import os
 
 from utilities.io import Reader, Writer
@@ -13,7 +14,9 @@
 
 
 class CCNetDownloader(object):
-    r""" TODO: docstring """
+    r"""
+    This class downloads / loads ccnet data and writes it to a jsonl file.
+    """
 
     dataset_name = "ccnet"
 
@@ -148,12 +151,16 @@ def _load_snapshot(
             "tail": int(num_samples * 0.7)
         }
 
-        s3_client = init_client(
-            endpoint_url=self._endpoint_url,
-            signature_version="s3v4",
-            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
-            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")
-        )
+        if urlparse(self._cc_input_base_uri).scheme == "s3":
+            s3_client = init_client(
+                endpoint_url=self._endpoint_url,
+                signature_version="s3v4",
+                aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
+                aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")
+            )
+        else:
+            s3_client = None
+
         reader = Reader(
             schema=[("raw_content", str), ("language", str)],
             s3_client=s3_client
diff --git a/app/src/bloomfilter.py b/app/src/bloomfilter.py
index ffc81d8..e3a648f 100644
--- a/app/src/bloomfilter.py
+++ b/app/src/bloomfilter.py
@@ -177,7 +177,7 @@ def __parse_listings(self):
         return uris
 
     @staticmethod
-    def _load_file(uri, client) -> Tuple[ReadStatus, io.BytesIO]:
+    def __load_from_s3(uri, client):
         try:
             streaming_body = client.get_object(
                 Bucket=uri.netloc, Key=uri.path.lstrip("/")
@@ -193,6 +193,33 @@ def _load_file(uri, client) -> Tuple[ReadStatus, io.BytesIO]:
             buffer = None
             is_success = False
 
+        return is_success, msg, buffer
+
+    @staticmethod
+    def __load_from_disk(uri):
+        try:
+            with open(uri.path, "rb") as f:
+                buffer = io.BytesIO(f.read())
+            msg = f"__DISK_URI_READ_SUCCESS__ success reading {uri.path}"
+            is_success = True
+        except Exception as e:
+            msg = (
+                f"__DISK_URI_READ_ERROR__ failed reading {uri.path}: "
+                f"caught exception {e.__class__.__name__}: {e}"
+            )
+            buffer = None
+            is_success = False
+
+        return is_success, msg, buffer
+
+    def _load_file(self, uri, client) -> Tuple[ReadStatus, io.BytesIO]:
+        if uri.scheme == "s3":
+            is_success, msg, buffer = self.__load_from_s3(uri, client)
+        elif uri.scheme == "file":
+            is_success, msg, buffer = self.__load_from_disk(uri)
+        else:
+            raise ValueError(f"Unknown scheme {uri.scheme}")
+
         read_status = ReadStatus(
             is_success=is_success, msg=msg, uri=uri.geturl()
         )
@@ -394,7 +421,6 @@ def __parallel_run(self, input_uris):
     def run(self):
         start_time = dt.now()
         print(f"start @ {start_time.strftime('%Y-%m-%d %H:%M:%S')}")
-        # self.__threaded_run(input_uris=self.__parse_listings())
         self.__parallel_run(input_uris=self.__parse_listings())
         end_time = dt.now()
         print(f"end @ {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
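The file:// branch added above relies on urllib.parse.urlparse splitting S3 and local URIs differently: for s3:// the bucket lands in netloc, while for file:// the absolute path lands in path and netloc is empty. A minimal sketch of that behavior (the example URIs are made up; only urlparse itself is real):

from urllib.parse import urlparse

for raw in ("s3://my-bucket/listings/shard-00.gz",
            "file:///data/listings/shard-00.gz"):
    uri = urlparse(raw)
    # s3://   -> scheme='s3',   netloc='my-bucket', path='/listings/shard-00.gz'
    # file:// -> scheme='file', netloc='',          path='/data/listings/shard-00.gz'
    print(uri.scheme, uri.netloc, uri.path)

This is why __load_from_s3 can use uri.netloc as the bucket and uri.path.lstrip("/") as the key, while __load_from_disk can pass uri.path straight to open().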
diff --git a/app/src/prep_artifacts.py b/app/src/prep_artifacts.py
index d0064f4..cfd52cc 100644
--- a/app/src/prep_artifacts.py
+++ b/app/src/prep_artifacts.py
@@ -63,8 +63,7 @@ def nullable_string(val):
     help="Number of samples to use for classifiers"
 )
 parser.add_argument(
-    "--endpoint_url", type=nullable_string,
-    default=None,
+    "--endpoint_url", type=nullable_string, default=None,
     help="endpoint url where the s3 bucket is exposed."
 )
 
@@ -88,7 +87,11 @@ def main(artifacts_dir: str, cc_input: str, cc_input_base_uri: str,
          classifiers_num_samples: int, max_samples_per_book: int,
          max_paragraphs_per_book_sample: int
 ):
 
-    max_workers = min(max_workers, os.cpu_count() - 2)
+    if max_workers is None:
+        max_workers = os.cpu_count() - 2
+    else:
+        max_workers = min(max_workers, os.cpu_count() - 2)
+
     # parse config
     num_samples = max(dsir_num_samples, classifiers_num_samples)
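A quick sketch of the max_workers fallback added in main(), with the logic factored into a helper for illustration (the helper name is hypothetical; the patch inlines this). Previously, if max_workers arrived as None (e.g. an unset argparse default), min(None, os.cpu_count() - 2) would raise a TypeError under Python 3; the None check restores a sane default instead.

import os

def resolve_max_workers(max_workers):
    # Hypothetical helper mirroring the patched logic in main():
    # default to cpu_count() - 2, and cap any user-supplied value there too.
    # Note: os.cpu_count() can return None on platforms where the count is
    # undeterminable; both the patch and this sketch assume it returns an int.
    cap = os.cpu_count() - 2
    return cap if max_workers is None else min(max_workers, cap)

# On a 16-core machine:
#   resolve_max_workers(None) -> 14   (fall back to cpu_count() - 2)
#   resolve_max_workers(4)    -> 4    (user value kept when below the cap)
#   resolve_max_workers(64)   -> 14   (user value capped)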