Skip to content

Commit

Permalink
fix dataset list key error (#279)
Browse files Browse the repository at this point in the history
* fix dataset list key error
  • Loading branch information
lfayoux authored Aug 15, 2023
1 parent f09dd6e commit 669ec6e
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 4.20.1
- [#279] (https://github.com/cohere-ai/cohere-python/pull/279)
- Fix dataset listing key error

## 4.20.0

- [#276] (https://github.com/cohere-ai/cohere-python/pull/276)
Expand Down
2 changes: 1 addition & 1 deletion cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ def list_datasets(self, dataset_type: str = None, limit: int = None, offset: int
response = self._request(f"{cohere.DATASET_URL}", method="GET", params=param_dict)
return [
Dataset.from_dict({"meta": response.get("meta"), **r}, wait_fn=self.wait_for_dataset)
for r in response["datasets"]
for r in (response.get("datasets") or [])
]

def delete_dataset(self, id: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ async def list_datasets(
response = await self._request(f"{cohere.DATASET_URL}", method="GET", params=param_dict)
return [
AsyncDataset.from_dict({"meta": response.get("meta"), **r}, wait_fn=self.wait_for_dataset)
for r in response["datasets"]
for r in (response.get("datasets") or [])
]

async def delete_dataset(self, id: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "4.20.0"
version = "4.20.1"
description = ""
authors = ["Cohere"]
readme = "README.md"
Expand Down
7 changes: 7 additions & 0 deletions tests/async/test_async_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import io
import json
import os
import time
from typing import Optional

Expand All @@ -9,8 +10,11 @@
from cohere import AsyncClient
from cohere.responses import AsyncDataset

IN_CI = os.getenv("CI", "").lower() in ["true", "1"]


@pytest.mark.asyncio
@pytest.mark.skipif(IN_CI, reason="can timeout during high load")
async def test_async_create_dataset(async_client: AsyncClient):
dataset = await async_client.create_dataset(
name="ci-test",
Expand All @@ -35,6 +39,7 @@ async def test_async_create_dataset(async_client: AsyncClient):


@pytest.mark.asyncio
@pytest.mark.skipif(IN_CI, reason="can timeout during high load")
async def test_async_create_invalid_dataset(async_client: AsyncClient):
dataset = await async_client.create_dataset(
name="ci-test",
Expand All @@ -59,13 +64,15 @@ async def test_async_create_invalid_dataset(async_client: AsyncClient):


@pytest.mark.asyncio
@pytest.mark.skipif(IN_CI, reason="can timeout during high load")
async def test_async_get_dataset(async_client: AsyncClient):
datasets = await async_client.list_datasets()
dataset = await async_client.get_dataset(datasets[0].id)
check_result(dataset)


@pytest.mark.asyncio
@pytest.mark.skipif(IN_CI, reason="can timeout during high load")
async def test_async_list_dataset(async_client: AsyncClient):
datasets = await async_client.list_datasets()
assert len(datasets) > 0
Expand Down
6 changes: 5 additions & 1 deletion tests/sync/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import unittest
from typing import Optional

from utils import get_api_key
from utils import get_api_key, in_ci

import cohere
from cohere.responses import Dataset


class TestDataset(unittest.TestCase):
@unittest.skipIf(in_ci(), "can sometimes fail due to duration variation")
def test_create_dataset(self):
co = self.create_co()
dataset = co.create_dataset(
Expand All @@ -33,6 +34,7 @@ def test_create_dataset(self):

self.check_result(dataset, status="validated")

@unittest.skipIf(in_ci(), "can sometimes fail due to duration variation")
def test_create_invalid_dataset(self):
co = self.create_co()
dataset = co.create_dataset(
Expand All @@ -55,12 +57,14 @@ def test_create_invalid_dataset(self):

self.check_result(dataset, status="failed")

@unittest.skipIf(in_ci(), "can sometimes fail due to duration variation")
def test_get_dataset(self):
co = self.create_co()
datasets = co.list_datasets()
dataset = co.get_dataset(datasets[0].id)
self.check_result(dataset)

@unittest.skipIf(in_ci(), "can sometimes fail due to duration variation")
def test_list_dataset(self):
co = self.create_co()
datasets = co.list_datasets()
Expand Down

0 comments on commit 669ec6e

Please sign in to comment.