Skip to content

Commit

Permalink
add reinit to why.init()
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-rogers committed Aug 28, 2024
1 parent 05bd6f1 commit faf0a78
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions python/tests/api/writer/test_whylabs_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_whylabs_writer_throttle_retry():
MODEL_ID = "XXX"
uri = f"{ENDPOINT}/v0/organizations/{ORG_ID}/log/async/{MODEL_ID}"
httpretty.register_uri(httpretty.POST, uri, status=429) # Fake WhyLabs that throttles
why.init(force_local=True)
why.init(reinit=True, force_local=True)
data = {"col1": 1, "col2": "foo"}
result = why.log(data)
writer = WhyLabsWriter(dataset_id=MODEL_ID)
Expand Down Expand Up @@ -364,7 +364,7 @@ def test_log_batch():
def test_whylabs_writer():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema()
data = {"col1": 1, "col2": "foo"}
trace_id = str(uuid4())
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_whylabs_writer():
def test_whylabs_writer_segmented(zipped: bool):
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema(segments=segment_on_column("col1"))
data = {"col1": [1, 2, 1, 3, 2, 2], "col2": ["foo", "bar", "wat", "foo", "baz", "wat"]}
df = pd.DataFrame(data)
Expand Down Expand Up @@ -436,7 +436,7 @@ def test_whylabs_writer_segmented(zipped: bool):
def test_whylabs_writer_reference(segmented: bool, zipped: bool):
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
if segmented:
schema = DatasetSchema(segments=segment_on_column("col1"))
else:
Expand Down Expand Up @@ -567,7 +567,7 @@ def test_estimation_result():
def test_transactions():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema()
data = {"col1": 1, "col2": "foo"}
trace_id = str(uuid4())
Expand Down Expand Up @@ -600,7 +600,7 @@ def test_transactions():
def test_transaction_aborted():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
data = {"col1": 1, "col2": "foo"}
trace_id = str(uuid4())
result = why.log(data, trace_id=trace_id)
Expand Down Expand Up @@ -629,7 +629,7 @@ def test_transaction_aborted():
def test_transaction_context():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema()
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
df = pd.read_csv(csv_url)
Expand Down Expand Up @@ -668,7 +668,7 @@ def test_transaction_context():
def test_old_transaction_context():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema()
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
df = pd.read_csv(csv_url)
Expand Down Expand Up @@ -706,7 +706,7 @@ def test_old_transaction_context():

@pytest.mark.load
def test_transaction_context_aborted():
why.init(force_local=True)
why.init(reinit=True, force_local=True)
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
df = pd.read_csv(csv_url)
pdfs = np.array_split(df, 7)
Expand All @@ -726,7 +726,7 @@ def test_transaction_context_aborted():
def test_transaction_segmented():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema(segments=segment_on_column("Gender"))
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
data = pd.read_csv(csv_url)
Expand Down Expand Up @@ -767,7 +767,7 @@ def test_transaction_segmented():
def test_transaction_distributed():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(force_local=True)
why.init(reinit=True, force_local=True)
schema = DatasetSchema()
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
df = pd.read_csv(csv_url)
Expand Down

0 comments on commit faf0a78

Please sign in to comment.