Skip to content

Commit

Permalink
[Rule Migration] Improve rule translation prompts and processes (#204021
Browse files Browse the repository at this point in the history
)

## Summary

This PR performs multiple changes that all focuses on improving the
quality of the results returned when we translate rules that do not
match with a prebuilt rule and both with/without related integrations.

Changes include:

- Add a filter_index_patterns node, to always ensure `logs-*` is removed
with our `[indexPattern:logs-*]` value, which is similar to how we
detect missing lookups and macros.
- Split `translate_rule` into another `ecs_mapping` node, trying to
ensure translation focuses on changing SPL to ESQL without any focus on
actual field names, while the other node focuses only on the ESQL query
and changing field names.
- The summary now added in the comments have 1 for the translation and
one for the ECS mapping.
- Add default rule batch size `15` with PR comment/question.
- Ensure we only return one integration related rather than an array for
now, to make ESQL more focused on one related integration.
- New prompt to filter out one or more integrations from the returned
RAG; similar to how its done for rules RAG results already.
  • Loading branch information
P1llus authored Dec 12, 2024
1 parent 0dabc52 commit 0a7262d
Show file tree
Hide file tree
Showing 23 changed files with 560 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ export const ElasticRule = z.object({
*/
prebuilt_rule_id: NonEmptyString.optional(),
/**
* The Elastic integration IDs related to the rule.
* The Elastic integration ID found to be most relevant to the splunk rule.
*/
integration_ids: z.array(z.string()).optional(),
integration_id: z.string().optional(),
/**
* The Elastic rule id installed as a result.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,9 @@ components:
prebuilt_rule_id:
description: The Elastic prebuilt rule id matched.
$ref: '../../../common/api/model/primitives.schema.yaml#/components/schemas/NonEmptyString'
integration_ids:
type: array
items:
type: string
description: The Elastic integration IDs related to the rule.
integration_id:
type: string
description: The Elastic integration ID found to be most relevant to the splunk rule.
id:
description: The Elastic rule id installed as a result.
$ref: '../../../common/api/model/primitives.schema.yaml#/components/schemas/NonEmptyString'
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export const ruleMigrationsFieldMap: FieldMap<SchemaFieldMapKeys<Omit<RuleMigrat
'original_rule.annotations.mitre_attack': { type: 'keyword', array: true, required: false },
elastic_rule: { type: 'nested', required: false },
'elastic_rule.title': { type: 'text', required: true, fields: { keyword: { type: 'keyword' } } },
'elastic_rule.integration_ids': { type: 'keyword', array: true, required: false },
'elastic_rule.integration_id': { type: 'keyword', required: false },
'elastic_rule.query': { type: 'text', required: true },
'elastic_rule.query_language': { type: 'keyword', required: true },
'elastic_rule.description': { type: 'text', required: false },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

import { END, START, StateGraph } from '@langchain/langgraph';
import { SiemMigrationRuleTranslationResult } from '../../../../../../common/siem_migrations/constants';
import { getCreateSemanticQueryNode } from './nodes/create_semantic_query';
import { getMatchPrebuiltRuleNode } from './nodes/match_prebuilt_rule';
import { getProcessQueryNode } from './nodes/process_query';
Expand All @@ -23,9 +24,11 @@ export function getRuleMigrationAgent({
}: MigrateRuleGraphParams) {
const matchPrebuiltRuleNode = getMatchPrebuiltRuleNode({
model,
logger,
ruleMigrationsRetriever,
});
const translationSubGraph = getTranslateRuleGraph({
model,
inferenceClient,
ruleMigrationsRetriever,
connectorId,
Expand All @@ -41,23 +44,26 @@ export function getRuleMigrationAgent({
.addNode('matchPrebuiltRule', matchPrebuiltRuleNode)
.addNode('translationSubGraph', translationSubGraph)
// Edges
.addEdge(START, 'processQuery')
.addEdge('processQuery', 'createSemanticQuery')
.addEdge(START, 'createSemanticQuery')
.addEdge('createSemanticQuery', 'matchPrebuiltRule')
.addConditionalEdges('matchPrebuiltRule', matchedPrebuiltRuleConditional, [
'translationSubGraph',
END,
])
.addConditionalEdges('matchPrebuiltRule', matchedPrebuiltRuleConditional, ['processQuery', END])
.addEdge('processQuery', 'translationSubGraph')
.addEdge('translationSubGraph', END);

const graph = siemMigrationAgentGraph.compile();
graph.name = 'Rule Migration Graph'; // Customizes the name displayed in LangSmith
return graph;
}

/*
* If the original splunk rule has no prebuilt rule match, we will start processing the query, unless it is related to input/outputlookups.
*/
const matchedPrebuiltRuleConditional = (state: MigrateRuleState) => {
if (state.elastic_rule?.prebuilt_rule_id) {
return END;
}
return 'translationSubGraph';
if (state.translation_result === SiemMigrationRuleTranslationResult.UNTRANSLATABLE) {
return END;
}
return 'processQuery';
};
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Go through the relevant title, description and data sources from the above query
- Include keywords that are relevant to the use case.
- Add related keywords you detected from the above query, like one or more vendor, product, cloud provider, OS platform etc.
- Always reply with a JSON object with the key "semantic_query" and the value as the semantic search query inside three backticks as shown in the below example.
- If the related query focuses on Endpoint datamodel, make sure that "endpoint", "security" keywords are included.
</guidelines>
<example_response>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* 2.0.
*/

import type { Logger } from '@kbn/core/server';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import { SiemMigrationRuleTranslationResult } from '../../../../../../../../common/siem_migrations/constants';
import type { RuleMigrationsRetriever } from '../../../retrievers';
Expand All @@ -14,16 +15,20 @@ import { MATCH_PREBUILT_RULE_PROMPT } from './prompts';

interface GetMatchPrebuiltRuleNodeParams {
model: ChatModel;
logger: Logger;
ruleMigrationsRetriever: RuleMigrationsRetriever;
}

interface GetMatchedRuleResponse {
match: string;
}

export const getMatchPrebuiltRuleNode =
({ model, ruleMigrationsRetriever }: GetMatchPrebuiltRuleNodeParams): GraphNode =>
async (state) => {
export const getMatchPrebuiltRuleNode = ({
model,
ruleMigrationsRetriever,
logger,
}: GetMatchPrebuiltRuleNodeParams): GraphNode => {
return async (state) => {
const query = state.semantic_query;
const techniqueIds = state.original_rule.annotations?.mitre_attack || [];
const prebuiltRules = await ruleMigrationsRetriever.prebuiltRules.getRules(
Expand All @@ -32,7 +37,7 @@ export const getMatchPrebuiltRuleNode =
);

const outputParser = new JsonOutputParser();
const matchPrebuiltRule = MATCH_PREBUILT_RULE_PROMPT.pipe(model).pipe(outputParser);
const mostRelevantRule = MATCH_PREBUILT_RULE_PROMPT.pipe(model).pipe(outputParser);

const elasticSecurityRules = prebuiltRules.map((rule) => {
return {
Expand All @@ -41,9 +46,17 @@ export const getMatchPrebuiltRuleNode =
};
});

const response = (await matchPrebuiltRule.invoke({
const splunkRule = {
title: state.original_rule.title,
description: state.original_rule.description,
};

/*
* Takes the most relevant rule from the array of rule(s) returned by the semantic query, returns either the most relevant or none.
*/
const response = (await mostRelevantRule.invoke({
rules: JSON.stringify(elasticSecurityRules, null, 2),
ruleTitle: state.original_rule.title,
splunk_rule: JSON.stringify(splunkRule, null, 2),
})) as GetMatchedRuleResponse;
if (response.match) {
const matchedRule = prebuiltRules.find((r) => r.name === response.match);
Expand All @@ -59,5 +72,16 @@ export const getMatchPrebuiltRuleNode =
};
}
}
const lookupTypes = ['inputlookup', 'outputlookup'];
if (
state.original_rule?.query &&
lookupTypes.some((type) => state.original_rule.query.includes(type))
) {
logger.debug(
`Rule: ${state.original_rule?.title} did not match any prebuilt rule, but contains inputlookup, dropping`
);
return { translation_result: SiemMigrationRuleTranslationResult.UNTRANSLATABLE };
}
return {};
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,22 @@ Here are some context for you to reference for your task, read it carefully as y
[
'human',
`See the below description of the relevant splunk rule and try to match it with any of the elastic detection rules with similar names.
<splunk_rule_name>
{ruleTitle}
</splunk_rule_name>
<splunk_rule>
{splunk_rule}
</splunk_rule>
<guidelines>
- Always reply with a JSON object with the key "match" and the value being the most relevant matched elastic detection rule name. Do not reply with anything else.
- Only reply with exact matches, if you are unsure or do not find a very confident match, always reply with an empty string value in the match key, do not guess or reply with anything else.
- If there is one Elastic rule in the list that covers the same threat, set the name of the matching rule as a value of the match key. Do not reply with anything else.
- If there are multiple rules in the list that cover the same threat, answer with the most specific of them, for example: "Linux User Account Creation" is more specific than "User Account Creation".
- If there is one Elastic rule in the list that covers the same usecase, set the name of the matching rule as a value of the match key. Do not reply with anything else.
- If there are multiple rules in the list that cover the same usecase, answer with the most specific of them, for example: "Linux User Account Creation" is more specific than "User Account Creation".
</guidelines>
<example_response>
U: <splunk_rule_name>
Linux Auditd Add User Account Type
</splunk_rule_name>
U: <splunk_rule>
Title: Linux Auditd Add User Account Type
Description: The following analytic detects the suspicious add user account type.
</splunk_rule>
A: Please find the match JSON object below:
\`\`\`json
{{"match": "Linux User Account Creation"}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import type {
OriginalRule,
RuleMigration,
} from '../../../../../../common/siem_migrations/model/rule_migration.gen';
import type { Integration } from '../../types';

export const migrateRuleState = Annotation.Root({
messages: Annotation<BaseMessage[]>({
Expand All @@ -32,10 +31,6 @@ export const migrateRuleState = Annotation.Root({
reducer: (current, value) => value ?? current,
default: () => '',
}),
integrations: Annotation<Integration[]>({
reducer: (current, value) => value ?? current,
default: () => [],
}),
translation_result: Annotation<SiemMigrationRuleTranslationResult>(),
comments: Annotation<RuleMigration['comments']>({
reducer: (current, value) => (value ? (current ?? []).concat(value) : current),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import { END, START, StateGraph } from '@langchain/langgraph';
import { isEmpty } from 'lodash/fp';
import { SiemMigrationRuleTranslationResult } from '../../../../../../../../common/siem_migrations/constants';
import { getEcsMappingNode } from './nodes/ecs_mapping';
import { getFilterIndexPatternsNode } from './nodes/filter_index_patterns';
import { getFixQueryErrorsNode } from './nodes/fix_query_errors';
import { getRetrieveIntegrationsNode } from './nodes/retrieve_integrations';
import { getTranslateRuleNode } from './nodes/translate_rule';
Expand All @@ -19,6 +21,7 @@ import type { TranslateRuleGraphParams, TranslateRuleState } from './types';
const MAX_VALIDATION_ITERATIONS = 3;

export function getTranslateRuleGraph({
model,
inferenceClient,
connectorId,
ruleMigrationsRetriever,
Expand All @@ -31,20 +34,30 @@ export function getTranslateRuleGraph({
});
const validationNode = getValidationNode({ logger });
const fixQueryErrorsNode = getFixQueryErrorsNode({ inferenceClient, connectorId, logger });
const retrieveIntegrationsNode = getRetrieveIntegrationsNode({ ruleMigrationsRetriever });
const retrieveIntegrationsNode = getRetrieveIntegrationsNode({ model, ruleMigrationsRetriever });
const ecsMappingNode = getEcsMappingNode({ inferenceClient, connectorId, logger });
const filterIndexPatternsNode = getFilterIndexPatternsNode({ logger });

const translateRuleGraph = new StateGraph(translateRuleState)
// Nodes
.addNode('translateRule', translateRuleNode)
.addNode('validation', validationNode)
.addNode('fixQueryErrors', fixQueryErrorsNode)
.addNode('retrieveIntegrations', retrieveIntegrationsNode)
.addNode('ecsMapping', ecsMappingNode)
.addNode('filterIndexPatterns', filterIndexPatternsNode)
// Edges
.addEdge(START, 'retrieveIntegrations')
.addEdge('retrieveIntegrations', 'translateRule')
.addEdge('translateRule', 'validation')
.addEdge('fixQueryErrors', 'validation')
.addConditionalEdges('validation', validationRouter, ['fixQueryErrors', END]);
.addEdge('ecsMapping', 'validation')
.addConditionalEdges('validation', validationRouter, [
'fixQueryErrors',
'ecsMapping',
'filterIndexPatterns',
])
.addEdge('filterIndexPatterns', END);

const graph = translateRuleGraph.compile();
graph.name = 'Translate Rule Graph';
Expand All @@ -59,6 +72,9 @@ const validationRouter = (state: TranslateRuleState) => {
if (!isEmpty(state.validation_errors?.esql_errors)) {
return 'fixQueryErrors';
}
if (!state.translation_finalized) {
return 'ecsMapping';
}
}
return END;
return 'filterIndexPatterns';
};
Loading

0 comments on commit 0a7262d

Please sign in to comment.