diff --git a/.github/workflows/ci-minimal-dependency-check.yml b/.github/workflows/ci-minimal-dependency-check.yml
new file mode 100644
index 00000000..d6e75471
--- /dev/null
+++ b/.github/workflows/ci-minimal-dependency-check.yml
@@ -0,0 +1,33 @@
+name: Minimal dependency check
+
+on:
+  push:
+    branches: [main, "release/*"]
+  pull_request:
+    branches: [main, "release/*"]
+
+defaults:
+  run:
+    shell: bash
+
+jobs:
+  pytester:
+    runs-on: ubuntu-latest
+
+    timeout-minutes: 30
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.9"
+
+      - name: Install LitServe
+        run: |
+          pip --version
+          pip install . psutil -U -q
+          pip list
+
+      - name: Tests
+        run: python tests/minimal_run.py
diff --git a/src/litserve/connector.py b/src/litserve/connector.py
index dbf788ca..c771c50f 100644
--- a/src/litserve/connector.py
+++ b/src/litserve/connector.py
@@ -72,12 +72,23 @@ def _auto_device_count(self, accelerator) -> int:
 
     @staticmethod
     def _choose_gpu_accelerator_backend():
-        import torch
-
+        """Pick the GPU backend: ``"cuda"``, ``"mps"``, or ``None`` when neither applies.
+
+        ``torch`` is optional and imported lazily; when it is missing the MPS
+        check is skipped and ``None`` is returned.
+        """
         if check_cuda_with_nvidia_smi() > 0:
             return "cuda"
-        if torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
-            return "mps"
+
+        # Guard only the import so unrelated errors from the MPS probe still surface.
+        try:
+            import torch
+        except ImportError:
+            return None
+
+        if torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
+            return "mps"
+
         return None
 
 
diff --git a/tests/minimal_run.py b/tests/minimal_run.py
new file mode 100644
index 00000000..33c46338
--- /dev/null
+++ b/tests/minimal_run.py
@@ -0,0 +1,96 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Smoke test: start ``tests/simple_server.py`` and POST one request to ``/predict``."""
+
+import json
+import socket
+import subprocess
+import sys
+import time
+import urllib.request
+
+import psutil
+
+SERVER_HOST = "127.0.0.1"
+SERVER_PORT = 8000
+STARTUP_TIMEOUT = 30.0  # seconds to wait for the port to open
+REQUEST_TIMEOUT = 30.0  # seconds before the HTTP request is abandoned
+
+
+def _wait_for_server(process, host, port, timeout):
+    """Block until ``host:port`` accepts TCP connections.
+
+    Raises:
+        RuntimeError: if ``process`` exits first or ``timeout`` seconds elapse.
+    """
+    deadline = time.monotonic() + timeout
+    while time.monotonic() < deadline:
+        exit_code = process.poll()
+        if exit_code is not None:
+            raise RuntimeError(f"Server exited during startup with code {exit_code}")
+        try:
+            with socket.create_connection((host, port), timeout=1.0):
+                return
+        except OSError:
+            # Not listening yet; back off briefly before the next probe.
+            time.sleep(0.2)
+    raise RuntimeError(f"Server did not open {host}:{port} within {timeout} seconds")
+
+
+def _kill_process_tree(process):
+    """Kill ``process`` and every descendant, tolerating ones that already exited."""
+    try:
+        descendants = psutil.Process(process.pid).children(recursive=True)
+    except psutil.NoSuchProcess:
+        descendants = []
+    for child in descendants:
+        try:
+            child.kill()
+        except psutil.NoSuchProcess:
+            # Exited between the listing and the kill: nothing left to do.
+            continue
+    process.kill()
+    # Reap the child so it is not left behind as a zombie.
+    process.wait()
+
+
+def main():
+    """Launch the example server, send one prediction request and require HTTP 200.
+
+    Raises:
+        RuntimeError: if the server never becomes reachable.
+        AssertionError: if the response status is not 200.
+        urllib.error.URLError: if the request itself fails.
+    """
+    # Use the running interpreter rather than whatever ``python`` resolves to on PATH.
+    process = subprocess.Popen([sys.executable, "tests/simple_server.py"])
+    print("Waiting for server to start...")
+    try:
+        _wait_for_server(process, SERVER_HOST, SERVER_PORT, STARTUP_TIMEOUT)
+
+        url = f"http://{SERVER_HOST}:{SERVER_PORT}/predict"
+        data = json.dumps({"input": 4.0}).encode("utf-8")
+        headers = {"Content-Type": "application/json"}
+        request = urllib.request.Request(url, data=data, headers=headers, method="POST")
+        with urllib.request.urlopen(request, timeout=REQUEST_TIMEOUT) as response:
+            status_code = response.getcode()
+        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
+        if status_code != 200:
+            raise AssertionError(f"Expected HTTP 200 from {url}, got {status_code}")
+    finally:
+        _kill_process_tree(process)
+
+
+if __name__ == "__main__":
+    main()