diff --git a/component_sdk/python/kfp_component/__init__.py b/component_sdk/python/kfp_component/__init__.py index 689556d9213..1bd4496142b 100644 --- a/component_sdk/python/kfp_component/__init__.py +++ b/component_sdk/python/kfp_component/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import launcher, core +from . import launcher, core, google diff --git a/component_sdk/python/kfp_component/google/__init__.py b/component_sdk/python/kfp_component/google/__init__.py index c2fc82ab83f..e8a8d80fe37 100644 --- a/component_sdk/python/kfp_component/google/__init__.py +++ b/component_sdk/python/kfp_component/google/__init__.py @@ -10,4 +10,6 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. + +from . import ml_engine, dataflow \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/dataflow/__init__.py b/component_sdk/python/kfp_component/google/dataflow/__init__.py new file mode 100644 index 00000000000..0ede8f7495d --- /dev/null +++ b/component_sdk/python/kfp_component/google/dataflow/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# file: component_sdk/python/kfp_component/google/dataflow/__init__.py

# Public API of the dataflow package: the two job-launching entry points.
from ._launch_template import launch_template
from ._launch_python import launch_python

# file: component_sdk/python/kfp_component/google/dataflow/_client.py

import googleapiclient.discovery as discovery
from googleapiclient import errors

class DataflowClient:
    """Thin wrapper around the Google Dataflow ``v1b3`` REST API client."""

    def __init__(self):
        # Builds the discovery-based API client; credentials come from the
        # environment (application default credentials).
        self._df = discovery.build('dataflow', 'v1b3')

    def launch_template(self, project_id, gcs_path, location,
        validate_only, launch_parameters):
        """Launches a Dataflow job from a template stored in GCS.

        Args:
            project_id: ID of the project that owns the job.
            gcs_path: GCS path of the template to launch.
            location: regional endpoint to direct the request to.
            validate_only: if true, only validates the request.
            launch_parameters: ``LaunchTemplateParameters`` dict used as the
                request body.

        Returns:
            The API response dict (contains a ``job`` entry on success).
        """
        return self._df.projects().templates().launch(
            projectId = project_id,
            gcsPath = gcs_path,
            location = location,
            validateOnly = validate_only,
            body = launch_parameters
        ).execute()

    def get_job(self, project_id, job_id, location=None, view=None):
        """Fetches a single Dataflow job by its id."""
        return self._df.projects().jobs().get(
            projectId = project_id,
            jobId = job_id,
            location = location,
            view = view
        ).execute()

    def cancel_job(self, project_id, job_id, location):
        """Requests cancellation of a job by setting its requested state."""
        return self._df.projects().jobs().update(
            projectId = project_id,
            jobId = job_id,
            location = location,
            body = {
                'requestedState': 'JOB_STATE_CANCELLED'
            }
        ).execute()

    def list_aggregated_jobs(self, project_id, filter=None,
        view=None, page_size=None, page_token=None,
        location=None):
        """Lists jobs across all regions of a project (paginated).

        NOTE(review): the ``filter`` parameter shadows the builtin; kept
        as-is because renaming it would break keyword callers.
        """
        return self._df.projects().jobs().aggregated(
            projectId = project_id,
            filter = filter,
            view = view,
            pageSize = page_size,
            pageToken = page_token,
            location = location).execute()

# file: component_sdk/python/kfp_component/google/dataflow/_common_ops.py

import logging
import time
import json
import os
import tempfile

from kfp_component.core import display
from .. import common as gcp_common
from ..storage import download_blob, parse_blob_path, is_gcs_path

# Job states that count as a successful completion.
_JOB_SUCCESSFUL_STATES = ['JOB_STATE_DONE', 'JOB_STATE_UPDATED', 'JOB_STATE_DRAINED']
# Job states that count as a terminal failure.
_JOB_FAILED_STATES = ['JOB_STATE_STOPPED', 'JOB_STATE_FAILED', 'JOB_STATE_CANCELLED']
_JOB_TERMINATED_STATES = _JOB_SUCCESSFUL_STATES + _JOB_FAILED_STATES

def generate_job_name(job_name, context_id):
    """Generates a stable job name in the job context.

    If user provided ``job_name`` has value, the function will use it
    as a prefix and appends first 8 characters of ``context_id`` to
    make the name unique across contexts. If the ``job_name`` is empty,
    it will use ``job-{context_id}`` as the job name.
    """
    if job_name:
        return '{}-{}'.format(
            gcp_common.normalize_name(job_name),
            context_id[:8])

    return 'job-{}'.format(context_id)

def get_job_by_name(df_client, project_id, job_name, location=None):
    """Gets a job by its name.

    The function lists all jobs under a project or a region location.
    Compares their names with the ``job_name`` and return the job
    once it finds a match. If none of the jobs matches, it returns
    ``None``.
    """
    page_token = None
    while True:
        # Walk the aggregated job list page by page until a name matches.
        response = df_client.list_aggregated_jobs(project_id,
            page_size=50, page_token=page_token, location=location)
        for job in response.get('jobs', []):
            name = job.get('name', None)
            if job_name == name:
                return job
        page_token = response.get('nextPageToken', None)
        if not page_token:
            return None

def wait_for_job_done(df_client, project_id, job_id, location=None, wait_interval=30):
    """Polls a job until it terminates.

    Returns the job dict when it ends in a successful state; raises
    ``RuntimeError`` when it ends in a failed/cancelled state.
    """
    while True:
        job = df_client.get_job(project_id, job_id, location=location)
        state = job.get('currentState', None)
        if is_job_done(state):
            return job
        elif is_job_terminated(state):
            # Terminated with error state
            raise RuntimeError('Job {} failed with error state: {}.'.format(
                job_id,
                state
            ))
        else:
            logging.info('Job {} is in pending state {}.'
                ' Waiting for {} seconds for next poll.'.format(
                    job_id,
                    state,
                    wait_interval
                ))
            time.sleep(wait_interval)

def wait_and_dump_job(df_client, project_id, location, job,
    wait_interval):
    """Shows the console link, waits for completion and dumps the job JSON."""
    display_job_link(project_id, job)
    job_id = job.get('id')
    job = wait_for_job_done(df_client, project_id, job_id,
        location, wait_interval)
    dump_job(job)
    return job

def is_job_terminated(job_state):
    """True if ``job_state`` is terminal (successful or failed)."""
    return job_state in _JOB_TERMINATED_STATES

def is_job_done(job_state):
    """True if ``job_state`` is a successful terminal state."""
    return job_state in _JOB_SUCCESSFUL_STATES

def display_job_link(project_id, job):
    """Publishes a link to the Dataflow console page of the job."""
    location = job.get('location')
    job_id = job.get('id')
    display.display(display.Link(
        href = 'https://console.cloud.google.com/dataflow/'
        'jobsDetail/locations/{}/jobs/{}?project={}'.format(
            location, job_id, project_id),
        text = 'Job Details'
    ))

def dump_job(job):
    """Writes the job dict as JSON to the component's fixed output path."""
    gcp_common.dump_file('/tmp/output/job.json', json.dumps(job))

def stage_file(local_or_gcs_path):
    """Returns a local path for the file, downloading it first if on GCS."""
    if not is_gcs_path(local_or_gcs_path):
        return local_or_gcs_path
    _, blob_path = parse_blob_path(local_or_gcs_path)
    file_name = os.path.basename(blob_path)
    # Each staged file gets its own fresh temp directory.
    local_file_path = os.path.join(tempfile.mkdtemp(), file_name)
    download_blob(local_or_gcs_path, local_file_path)
    return local_file_path
# file: component_sdk/python/kfp_component/google/dataflow/_launch_python.py

import subprocess
import re
import logging

from kfp_component.core import KfpExecutionContext
from ._client import DataflowClient
from .. import common as gcp_common
from ._common_ops import (generate_job_name, get_job_by_name,
    wait_and_dump_job, stage_file)
from ._process import Process

def launch_python(python_file_path, project_id, requirements_file_path=None,
    location=None, job_name_prefix=None, args=None, wait_interval=30):
    """Launch a self-executing beam python file.

    Args:
        python_file_path (str): The gcs or local path to the python file to run.
        project_id (str): The ID of the parent project.
        requirements_file_path (str): Optional, the gcs or local path to the pip
            requirements file.
        location (str): The regional endpoint to which to direct the
            request.
        job_name_prefix (str): Optional. The prefix of the generated job
            name. If not provided, the method will generate a random name.
        args (list): The list of args to pass to the python file.
        wait_interval (int): The wait seconds between polling.
    Returns:
        The completed job.
    """
    # Fix: ``args=[]`` was a mutable default argument shared across calls;
    # default to None and normalize here instead.
    if args is None:
        args = []
    df_client = DataflowClient()
    job_id = None
    def cancel():
        # Best-effort cancellation hook invoked by KfpExecutionContext.
        if job_id:
            df_client.cancel_job(
                project_id,
                job_id,
                location
            )
    with KfpExecutionContext(on_cancel=cancel) as ctx:
        job_name = generate_job_name(
            job_name_prefix,
            ctx.context_id())
        # We always generate the same unique name for the same context, so a
        # job with this name may already exist from a previous retry of the
        # same pipeline run; if so, just resume waiting on it.
        job = get_job_by_name(df_client, project_id, job_name,
            location)
        if job:
            return wait_and_dump_job(df_client, project_id, location, job,
                wait_interval)

        _install_requirements(requirements_file_path)
        python_file_path = stage_file(python_file_path)
        cmd = _prepare_cmd(project_id, location, job_name, python_file_path,
            args)
        sub_process = Process(cmd)
        # Scan subprocess output for the Dataflow console URL that carries
        # the job id.
        for line in sub_process.read_lines():
            job_id = _extract_job_id(line)
            if job_id:
                logging.info('Found job id {}'.format(job_id))
                break
        sub_process.wait_and_check()
        if not job_id:
            logging.warning('No dataflow job was found when '
                'running the python file.')
            return None
        job = df_client.get_job(project_id, job_id,
            location=location)
        return wait_and_dump_job(df_client, project_id, location, job,
            wait_interval)

def _prepare_cmd(project_id, location, job_name, python_file_path, args):
    """Builds the command line that runs the Beam file on the Dataflow runner."""
    dataflow_args = [
        '--runner', 'dataflow',
        '--project', project_id,
        '--job-name', job_name]
    if location:
        dataflow_args += ['--location', location]
    return (['python2', '-u', python_file_path] +
        dataflow_args + args)

def _extract_job_id(line):
    """Returns the job id embedded in a Dataflow console URL in ``line``, or None."""
    job_id_pattern = re.compile(
        br'.*console.cloud.google.com/dataflow.*/jobs/([a-z|0-9|A-Z|\-|\_]+).*')
    # Fix: the fallback must be b'' (bytes), not '' — the pattern is a bytes
    # pattern and searching a str raises TypeError.
    matched_job_id = job_id_pattern.search(line or b'')
    if matched_job_id:
        return matched_job_id.group(1).decode()
    return None

def _install_requirements(requirements_file_path):
    """pip-installs the (possibly GCS-staged) requirements file, if given."""
    if not requirements_file_path:
        return
    requirements_file_path = stage_file(requirements_file_path)
    # Fix: fail fast when pip fails rather than continuing with missing
    # dependencies and producing a confusing downstream error.
    subprocess.run(['pip2', 'install', '-r', requirements_file_path],
        check=True)
# file: component_sdk/python/kfp_component/google/dataflow/_launch_template.py

import json
import logging
import re
import time

from kfp_component.core import KfpExecutionContext
from ._client import DataflowClient
from .. import common as gcp_common
from ._common_ops import (generate_job_name, get_job_by_name,
    wait_and_dump_job)

def launch_template(project_id, gcs_path, launch_parameters,
    location=None, job_name_prefix=None, validate_only=None,
    wait_interval=30):
    """Launches a dataflow job from template.

    Args:
        project_id (str): Required. The ID of the Cloud Platform project
            that the job belongs to.
        gcs_path (str): Required. A Cloud Storage path to the template
            from which to create the job. Must be valid Cloud
            Storage URL, beginning with 'gs://'.
        launch_parameters (dict): Parameters to provide to the template
            being launched. Schema defined in
            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/LaunchTemplateParameters.
            `jobName` will be replaced by generated name.
        location (str): The regional endpoint to which to direct the
            request.
        job_name_prefix (str): Optional. The prefix of the generated job
            name. If not provided, the method will generate a random name.
        validate_only (boolean): If true, the request is validated but
            not actually executed. Defaults to false.
        wait_interval (int): The wait seconds between polling.

    Returns:
        The completed job.
    """
    df_client = DataflowClient()
    # NOTE(review): job_id is never assigned after launch, so cancel() is
    # currently a no-op; presumably it should be set from job.get('id') —
    # confirm intended behavior before changing it.
    job_id = None
    def cancel():
        if job_id:
            df_client.cancel_job(
                project_id,
                job_id,
                location
            )
    with KfpExecutionContext(on_cancel=cancel) as ctx:
        job_name = generate_job_name(
            job_name_prefix,
            ctx.context_id())
        # Fix: removed a leftover debug ``print(job_name)`` statement.
        # Reuse the job from a previous retry of the same pipeline run if
        # one with the generated name already exists.
        job = get_job_by_name(df_client, project_id, job_name,
            location)
        if not job:
            launch_parameters['jobName'] = job_name
            response = df_client.launch_template(project_id, gcs_path,
                location, validate_only, launch_parameters)
            job = response.get('job')
        return wait_and_dump_job(df_client, project_id, location, job,
            wait_interval)

# file: component_sdk/python/kfp_component/google/dataflow/_process.py

import subprocess
import logging

class Process:
    """Runs a subprocess and streams its merged stdout/stderr lines."""

    def __init__(self, cmd):
        self._cmd = cmd
        # stderr is folded into stdout so callers scan a single stream.
        self.process = subprocess.Popen(cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            close_fds=True,
            shell=False)

    def read_lines(self):
        """Yields raw output lines (bytes) as the process produces them."""
        # stdout will end with empty bytes when process exits.
        for line in iter(self.process.stdout.readline, b''):
            logging.info('subprocess: {}'.format(line))
            yield line

    def wait_and_check(self):
        """Drains remaining output, waits for exit and raises on failure."""
        for _ in self.read_lines():
            pass
        self.process.stdout.close()
        return_code = self.process.wait()
        logging.info('Subprocess exit with code {}.'.format(
            return_code))
        if return_code:
            raise subprocess.CalledProcessError(return_code, self._cmd)

# file: component_sdk/python/kfp_component/google/storage/__init__.py

from ._download_blob import download_blob
from ._common_ops import parse_blob_path, is_gcs_path
# file: component_sdk/python/kfp_component/google/storage/_common_ops.py

import re

# A full GCS blob path: non-empty bucket name, then a non-empty blob name
# (the blob name itself may contain slashes).
_BLOB_PATH_RE = re.compile(r'gs://([^/]+)/(.+)$')

def is_gcs_path(path):
    """Return True if ``path`` looks like a GCS URI (starts with gs://)."""
    return path.startswith('gs://')

def parse_blob_path(path):
    """Split a GCS blob path into its bucket and blob names.

    Args:
        path (str): the path to parse.

    Returns:
        (bucket name in the path, blob name in the path)

    Raises:
        ValueError if the path is not a valid gcs blob path.

    Example:

        `bucket_name, blob_name = parse_blob_path('gs://foo/bar')`
        `bucket_name` is `foo` and `blob_name` is `bar`
    """
    matched = _BLOB_PATH_RE.match(path)
    if not matched:
        raise ValueError('Path {} is invalid blob path.'.format(
            path))
    return matched.group(1), matched.group(2)
# file: component_sdk/python/kfp_component/google/storage/_download_blob.py

import logging
import os

from google.cloud import storage
from ._common_ops import parse_blob_path

def download_blob(source_blob_path, destination_file_path):
    """Downloads a blob from the bucket.

    Args:
        source_blob_path (str): the source blob path to download from.
        destination_file_path (str): the local file path to download to.

    Raises:
        ValueError: if ``source_blob_path`` is not a valid gs:// blob path.
    """
    bucket_name, blob_name = parse_blob_path(source_blob_path)
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob(blob_name)

    dirname = os.path.dirname(destination_file_path)
    # Fix: guard against an empty dirname (destination in the current
    # working directory) — os.makedirs('') raises an error.
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    with open(destination_file_path, 'wb+') as f:
        blob.download_to_file(f)

    logging.info('Blob {} downloaded to {}.'.format(
        source_blob_path,
        destination_file_path))
# file: component_sdk/python/tests/google/dataflow/test__launch_python.py

import mock
import unittest
import os

from kfp_component.google.dataflow import launch_python

MODULE = 'kfp_component.google.dataflow._launch_python'

# All collaborators are patched so launch_python runs hermetically; mocks
# are injected bottom-up (last decorator becomes the first argument).
@mock.patch('kfp_component.google.dataflow._common_ops.display')
@mock.patch(MODULE + '.stage_file')
@mock.patch(MODULE + '.KfpExecutionContext')
@mock.patch(MODULE + '.DataflowClient')
@mock.patch(MODULE + '.Process')
@mock.patch(MODULE + '.subprocess')
class LaunchPythonTest(unittest.TestCase):
    """Unit tests for launch_python."""

    def test_launch_python_succeed(self, mock_subprocess, mock_process,
        mock_client, mock_context, mock_stage_file, mock_display):
        mock_context().__enter__().context_id.return_value = 'ctx-1'
        # No pre-existing job with the generated name.
        mock_client().list_aggregated_jobs.return_value = {
            'jobs': []
        }
        # The subprocess prints a Dataflow console URL carrying the job id.
        mock_process().read_lines.return_value = [
            b'https://console.cloud.google.com/dataflow/locations/us-central1/jobs/job-1?project=project-1'
        ]
        expected_job = {
            'currentState': 'JOB_STATE_DONE'
        }
        mock_client().get_job.return_value = expected_job

        result = launch_python('/tmp/test.py', 'project-1')

        self.assertEqual(expected_job, result)

    def test_launch_python_retry_succeed(self, mock_subprocess, mock_process,
        mock_client, mock_context, mock_stage_file, mock_display):
        mock_context().__enter__().context_id.return_value = 'ctx-1'
        # A job with the generated name already exists, so launch_python
        # should resume waiting on it instead of starting a new process.
        mock_client().list_aggregated_jobs.return_value = {
            'jobs': [{
                'id': 'job-1',
                'name': 'test_job-ctx-1'
            }]
        }
        expected_job = {
            'currentState': 'JOB_STATE_DONE'
        }
        mock_client().get_job.return_value = expected_job

        result = launch_python('/tmp/test.py', 'project-1', job_name_prefix='test-job')

        self.assertEqual(expected_job, result)
        # No new subprocess must be spawned on the retry path.
        mock_process.assert_not_called()

    def test_launch_python_no_job_created(self, mock_subprocess, mock_process,
        mock_client, mock_context, mock_stage_file, mock_display):
        mock_context().__enter__().context_id.return_value = 'ctx-1'
        mock_client().list_aggregated_jobs.return_value = {
            'jobs': []
        }
        # Output contains no console URL, so no job id can be extracted.
        mock_process().read_lines.return_value = [
            b'no job id',
            b'no job id'
        ]

        result = launch_python('/tmp/test.py', 'project-1')

        self.assertEqual(None, result)
# file: component_sdk/python/tests/google/dataflow/test__launch_template.py

import mock
import unittest
import os

from kfp_component.google.dataflow import launch_template

MODULE = 'kfp_component.google.dataflow._launch_template'

# Mocks are injected bottom-up (last decorator becomes the first argument).
@mock.patch('kfp_component.google.dataflow._common_ops.display')
@mock.patch(MODULE + '.KfpExecutionContext')
@mock.patch(MODULE + '.DataflowClient')
class LaunchTemplateTest(unittest.TestCase):
    """Unit tests for launch_template."""

    def test_launch_template_succeed(self, mock_client, mock_context, mock_display):
        mock_context().__enter__().context_id.return_value = 'context-1'
        # No pre-existing job, so a new template launch is expected.
        mock_client().list_aggregated_jobs.return_value = {
            'jobs': []
        }
        mock_client().launch_template.return_value = {
            'job': { 'id': 'job-1' }
        }
        expected_job = {
            'currentState': 'JOB_STATE_DONE'
        }
        mock_client().get_job.return_value = expected_job

        result = launch_template('project-1', 'gs://foo/bar', {
            "parameters": {
                "foo": "bar"
            },
            "environment": {
                "zone": "us-central1"
            }
        })

        self.assertEqual(expected_job, result)
        mock_client().launch_template.assert_called_once()

    def test_launch_template_retry_succeed(self,
        mock_client, mock_context, mock_display):
        mock_context().__enter__().context_id.return_value = 'ctx-1'
        # The job with same name already exists.
        mock_client().list_aggregated_jobs.return_value = {
            'jobs': [{
                'id': 'job-1',
                'name': 'test_job-ctx-1'
            }]
        }
        pending_job = {
            'currentState': 'JOB_STATE_PENDING'
        }
        expected_job = {
            'currentState': 'JOB_STATE_DONE'
        }
        # First poll sees a pending state, second poll sees completion.
        mock_client().get_job.side_effect = [pending_job, expected_job]

        result = launch_template('project-1', 'gs://foo/bar', {
            "parameters": {
                "foo": "bar"
            },
            "environment": {
                "zone": "us-central1"
            }
        }, job_name_prefix='test-job', wait_interval=0)

        self.assertEqual(expected_job, result)
        # The existing job must be reused; no new launch request.
        mock_client().launch_template.assert_not_called()

    def test_launch_template_fail(self, mock_client, mock_context, mock_display):
        mock_context().__enter__().context_id.return_value = 'context-1'
        mock_client().list_aggregated_jobs.return_value = {
            'jobs': []
        }
        mock_client().launch_template.return_value = {
            'job': { 'id': 'job-1' }
        }
        # A terminal failure state should surface as a RuntimeError.
        failed_job = {
            'currentState': 'JOB_STATE_FAILED'
        }
        mock_client().get_job.return_value = failed_job

        self.assertRaises(RuntimeError,
            lambda: launch_template('project-1', 'gs://foo/bar', {
                "parameters": {
                    "foo": "bar"
                },
                "environment": {
                    "zone": "us-central1"
                }
            }))
+# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/component_sdk/python/tests/google/storage/test__common_ops.py b/component_sdk/python/tests/google/storage/test__common_ops.py new file mode 100644 index 00000000000..9d0913dd0c3 --- /dev/null +++ b/component_sdk/python/tests/google/storage/test__common_ops.py @@ -0,0 +1,40 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# file: component_sdk/python/tests/google/storage/test__common_ops.py

import unittest

from kfp_component.google.storage import is_gcs_path, parse_blob_path

class CommonOpsTest(unittest.TestCase):
    """Unit tests for the GCS path helper functions."""

    def test_is_gcs_path(self):
        # Anything with the gs:// scheme prefix counts as a GCS path.
        for gcs_path in ('gs://foo', 'gs://foo/bar'):
            self.assertTrue(is_gcs_path(gcs_path))
        # Malformed scheme or plain local paths do not.
        for other_path in ('gs:/foo/bar', 'foo/bar'):
            self.assertFalse(is_gcs_path(other_path))

    def test_parse_blob_path_valid(self):
        bucket_name, blob_name = parse_blob_path('gs://foo/bar/baz/')

        self.assertEqual('foo', bucket_name)
        self.assertEqual('bar/baz/', blob_name)

    def test_parse_blob_path_invalid(self):
        # Missing blob name.
        with self.assertRaises(ValueError):
            parse_blob_path('gs://foo')
        with self.assertRaises(ValueError):
            parse_blob_path('gs://foo/')

        # Invalid GCS path / empty bucket name.
        with self.assertRaises(ValueError):
            parse_blob_path('foo')
        with self.assertRaises(ValueError):
            parse_blob_path('gs:///foo')
# file: component_sdk/python/tests/google/storage/test__download_blob.py

import mock
import unittest
import os

from kfp_component.google.storage import download_blob

DOWNLOAD_BLOB_MODULE = 'kfp_component.google.storage._download_blob'

# Patch the filesystem, open() and the GCS client so no real I/O happens.
@mock.patch(DOWNLOAD_BLOB_MODULE + '.os')
@mock.patch(DOWNLOAD_BLOB_MODULE + '.open')
@mock.patch(DOWNLOAD_BLOB_MODULE + '.storage.Client')
class DownloadBlobTest(unittest.TestCase):
    """Unit tests for download_blob."""

    def test_download_blob_succeed(self, mock_storage_client,
        mock_open, mock_os):
        # Destination directory does not exist yet, so it must be created.
        mock_os.path.dirname.return_value = '/foo'
        mock_os.path.exists.return_value = False

        download_blob('gs://foo/bar.py',
            '/foo/bar.py')

        mock_blob = mock_storage_client().get_bucket().blob()
        mock_blob.download_to_file.assert_called_once()
        mock_os.makedirs.assert_called_once()