Skip to content

Commit

Permalink
Adding Tech Preview badge for Reranker (elastic#202561)
Browse files Browse the repository at this point in the history
## Summary

Adding a `Tech Preview` badge for `reranker` model.


![reranker](https://github.com/user-attachments/assets/eb370f82-5127-4a9c-a00d-9a6d8adca34c)



### Checklist

Check the PR satisfies following conditions. 

Reviewers should verify this PR satisfies this list as well.

- [X] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [X] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
2 people authored and CAWilson94 committed Dec 9, 2024
1 parent 6c8ee9e commit e7fe5e2
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* 2.0.
*/

import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { i18n } from '@kbn/i18n';

export const ELSER_MODEL_ID = '.elser_model_2';
Expand Down Expand Up @@ -308,14 +309,7 @@ export type InferenceServiceSettings =
};
};

export type InferenceAPIConfigResponse = {
// Refers to a deployment id
inference_id: string;
task_type: 'sparse_embedding' | 'text_embedding';
task_settings: {
model?: string;
};
} & InferenceServiceSettings;
export type InferenceAPIConfigResponse = InferenceInferenceEndpointInfo & InferenceServiceSettings;

export function isLocalModel(
model: InferenceServiceSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,60 @@ import { EndpointInfo } from './endpoint_info';

describe('RenderEndpoint component tests', () => {
it('renders the component with inference id', () => {
render(<EndpointInfo inferenceId={'cohere-2'} />);
const mockProvider = {
inference_id: 'cohere-2',
service: 'cohere',
service_settings: {
similarity: 'cosine',
dimensions: 384,
model_id: 'embed-english-light-v3.0',
rate_limit: {
requests_per_minute: 10000,
},
embedding_type: 'byte',
},
task_settings: {},
} as any;

render(<EndpointInfo inferenceId={'cohere-2'} provider={mockProvider} />);

expect(screen.getByText('cohere-2')).toBeInTheDocument();
});

it('renders correctly without model_id in service_settings', () => {
const mockProvider = {
inference_id: 'azure-openai-1',
service: 'azureopenai',
service_settings: {
resource_name: 'resource-xyz',
deployment_id: 'deployment-123',
api_version: 'v1',
},
} as any;

render(<EndpointInfo inferenceId={'azure-openai-1'} provider={mockProvider} />);

expect(screen.getByText('azure-openai-1')).toBeInTheDocument();
});

it('renders with tech preview badge when endpoint is reranker type', () => {
const mockProvider = {
inference_id: 'elastic-rerank',
task_type: 'rerank',
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.rerank-v1',
},
task_settings: {
return_documents: true,
},
} as any;

render(<EndpointInfo inferenceId={'elastic-rerank'} provider={mockProvider} />);

expect(screen.getByText('elastic-rerank')).toBeInTheDocument();
expect(screen.getByText('TECH PREVIEW')).toBeInTheDocument();
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,38 @@

import { EuiBetaBadge, EuiFlexGroup, EuiFlexItem } from '@elastic/eui';
import React from 'react';
import { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import { isEndpointPreconfigured } from '../../../../utils/preconfigured_endpoint_helper';
import * as i18n from './translations';
import { isProviderTechPreview } from '../../../../utils/reranker_helper';

export interface EndpointInfoProps {
inferenceId: string;
provider: InferenceAPIConfigResponse;
}

export const EndpointInfo: React.FC<EndpointInfoProps> = ({ inferenceId }) => (
export const EndpointInfo: React.FC<EndpointInfoProps> = ({ inferenceId, provider }) => (
<EuiFlexGroup justifyContent="spaceBetween">
<EuiFlexItem grow={false}>
<span>
<strong>{inferenceId}</strong>
</span>
<EuiFlexGroup gutterSize="s" alignItems="center">
<EuiFlexItem grow={false}>
<span>
<strong>{inferenceId}</strong>
</span>
</EuiFlexItem>
{isProviderTechPreview(provider) ? (
<EuiFlexItem grow={false}>
<span>
<EuiBetaBadge
label={i18n.TECH_PREVIEW_LABEL}
size="s"
color="subdued"
alignment="middle"
/>
</span>
</EuiFlexItem>
) : null}
</EuiFlexGroup>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<span>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ export const PRECONFIGURED_LABEL = i18n.translate(
defaultMessage: 'PRECONFIGURED',
}
);

export const TECH_PREVIEW_LABEL = i18n.translate(
'xpack.searchInferenceEndpoints.elasticsearch.endpointInfo.techPreview',
{
defaultMessage: 'TECH PREVIEW',
}
);
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ const inferenceEndpoints = [
},
task_settings: {},
},
{
inference_id: 'elastic-rerank',
task_type: 'rerank',
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.rerank-v1',
},
task_settings: {
return_documents: true,
},
},
] as InferenceAPIConfigResponse[];

jest.mock('../../hooks/use_delete_endpoint', () => ({
Expand All @@ -82,9 +95,10 @@ describe('When the tabular page is loaded', () => {
const rows = screen.getAllByRole('row');
expect(rows[1]).toHaveTextContent('.elser-2-elasticsearch');
expect(rows[2]).toHaveTextContent('.multilingual-e5-small-elasticsearch');
expect(rows[3]).toHaveTextContent('local-model');
expect(rows[4]).toHaveTextContent('my-elser-model-05');
expect(rows[5]).toHaveTextContent('third-party-model');
expect(rows[3]).toHaveTextContent('elastic-rerank');
expect(rows[4]).toHaveTextContent('local-model');
expect(rows[5]).toHaveTextContent('my-elser-model-05');
expect(rows[6]).toHaveTextContent('third-party-model');
});

it('should display all service and model ids in the table', () => {
Expand All @@ -98,13 +112,16 @@ describe('When the tabular page is loaded', () => {
expect(rows[2]).toHaveTextContent('.multilingual-e5-small');

expect(rows[3]).toHaveTextContent('Elasticsearch');
expect(rows[3]).toHaveTextContent('.own_model');
expect(rows[3]).toHaveTextContent('.rerank-v1');

expect(rows[4]).toHaveTextContent('Elasticsearch');
expect(rows[4]).toHaveTextContent('.elser_model_2');
expect(rows[4]).toHaveTextContent('.own_model');

expect(rows[5]).toHaveTextContent('OpenAI');
expect(rows[5]).toHaveTextContent('.own_model');
expect(rows[5]).toHaveTextContent('Elasticsearch');
expect(rows[5]).toHaveTextContent('.elser_model_2');

expect(rows[6]).toHaveTextContent('OpenAI');
expect(rows[6]).toHaveTextContent('.own_model');
});

it('should only disable delete action for preconfigured endpoints', () => {
Expand All @@ -131,4 +148,18 @@ describe('When the tabular page is loaded', () => {
expect(rows[4]).not.toHaveTextContent(preconfigured);
expect(rows[5]).not.toHaveTextContent(preconfigured);
});

it('should show tech preview badge only for reranker-v1 model', () => {
render(<TabularPage inferenceEndpoints={inferenceEndpoints} />);

const techPreview = 'TECH PREVIEW';

const rows = screen.getAllByRole('row');
expect(rows[1]).not.toHaveTextContent(techPreview);
expect(rows[2]).not.toHaveTextContent(techPreview);
expect(rows[3]).toHaveTextContent(techPreview);
expect(rows[4]).not.toHaveTextContent(techPreview);
expect(rows[5]).not.toHaveTextContent(techPreview);
expect(rows[6]).not.toHaveTextContent(techPreview);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ export const TabularPage: React.FC<TabularPageProps> = ({ inferenceEndpoints })
field: 'endpoint',
name: i18n.ENDPOINT,
'data-test-subj': 'endpointCell',
render: (endpoint: string) => {
render: (endpoint: string, additionalInfo: InferenceEndpointUI) => {
if (endpoint) {
return <EndpointInfo inferenceId={endpoint} />;
return <EndpointInfo inferenceId={endpoint} provider={additionalInfo.provider} />;
}

return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { isProviderTechPreview } from './reranker_helper';

describe('Reranker Tech preview badge', () => {
const mockProvider = {
inference_id: 'elastic-rerank',
task_type: 'rerank',
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.rerank-v1',
},
task_settings: {
return_documents: true,
},
} as any;

it('return true for reranker', () => {
expect(isProviderTechPreview(mockProvider)).toEqual(true);
});

it('return false for other provider', () => {
const otherProviderServiceSettings = {
...mockProvider.service_settings,
model_id: '.elser_model_2',
};
const otherProvider = {
...mockProvider,
task_type: 'sparse_embedding',
service_settings: otherProviderServiceSettings,
} as any;
expect(isProviderTechPreview(otherProvider)).toEqual(false);
});

it('return false for other provider without model_id', () => {
const mockThirdPartyProvider = {
inference_id: 'azure-openai-1',
service: 'azureopenai',
service_settings: {
resource_name: 'resource-xyz',
deployment_id: 'deployment-123',
api_version: 'v1',
},
} as any;
expect(isProviderTechPreview(mockThirdPartyProvider)).toEqual(false);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
export const isProviderTechPreview = (provider: InferenceAPIConfigResponse) => {
if (hasModelId(provider)) {
return provider.task_type === 'rerank' && provider.service_settings?.model_id?.startsWith('.');
}

return false;
};

function hasModelId(
service: InferenceAPIConfigResponse
): service is Extract<InferenceAPIConfigResponse, { service_settings: { model_id: string } }> {
return 'model_id' in service.service_settings;
}

0 comments on commit e7fe5e2

Please sign in to comment.