Skip to content

Commit

Permalink
Add some tests for MitreAttackData
Browse files Browse the repository at this point in the history
  • Loading branch information
jondricek committed Oct 30, 2023
1 parent fd7b89e commit 18fd917
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 108 deletions.
29 changes: 29 additions & 0 deletions mitreattack/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,32 @@
# ['mitre-attack', 'mitre-mobile-attack', 'mobile-attack', 'mitre-ics-attack']
MITRE_ATTACK_ID_SOURCE_NAMES = ["mitre-attack", "mobile-attack", "mitre-mobile-attack", "mitre-ics-attack"]
MITRE_ATTACK_DOMAIN_STRINGS = ["mitre-attack", "mitre-mobile-attack", "mitre-ics-attack"]

PLATFORMS_LOOKUP = {
"enterprise-attack": [
"PRE",
"Windows",
"macOS",
"Linux",
"Cloud",
"Office 365",
"Azure AD",
"Google Workspace",
"SaaS",
"IaaS",
"Network",
"Containers",
],
"mobile-attack": ["Android", "iOS"],
"Cloud": ["Office 365", "Azure AD", "Google Workspace", "SaaS", "IaaS"],
"ics-attack": [
"Field Controller/RTU/PLC/IED",
"Safety Instrumented System/Protection Relay",
"Control Server",
"Input/Output Server",
"Windows",
"Human-Machine Interface",
"Engineering Workstation",
"Data Historian",
],
}
5 changes: 3 additions & 2 deletions mitreattack/stix20/MitreAttackData.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""MitreAttackData Library."""

import stix2
from itertools import chain

import stix2
from dateutil import parser
from itertools import chain
from stix2 import Filter
from stix2.utils import get_type_from_id

from mitreattack.stix20.custom_attack_objects import StixObjectFactory


Expand Down
38 changes: 28 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import shutil
from pathlib import Path

import pytest
from loguru import logger
Expand All @@ -8,6 +7,7 @@
from mitreattack.download_stix import download_domains
from mitreattack.navlayers import Layer
from mitreattack.release_info import LATEST_VERSION
from mitreattack.stix20 import MitreAttackData

from .resources.testing_data import example_layer_v3_all, example_layer_v43_dict

Expand Down Expand Up @@ -44,32 +44,50 @@ def stix_file_ics_latest(attack_stix_dir):


@pytest.fixture(scope="session")
def memstore_enterprise_latest(attack_stix_dir):
def memstore_enterprise_latest(stix_file_enterprise_latest):
logger.debug("Loading STIX memstore for Enterprise ATT&CK")
stix_file = f"{attack_stix_dir}/v{LATEST_VERSION}/enterprise-attack.json"
mem_store = MemoryStore()
mem_store.load_from_file(stix_file)
mem_store.load_from_file(stix_file_enterprise_latest)
return mem_store


@pytest.fixture(scope="session")
def memstore_mobile_latest(attack_stix_dir):
def memstore_mobile_latest(stix_file_mobile_latest):
logger.debug("Loading STIX memstore for Mobile ATT&CK")
stix_file = f"{Path.cwd()}/{attack_stix_dir}/v{LATEST_VERSION}/mobile-attack.json"
mem_store = MemoryStore()
mem_store.load_from_file(stix_file)
mem_store.load_from_file(stix_file_mobile_latest)
return mem_store


@pytest.fixture(scope="session")
def memstore_ics_latest(attack_stix_dir):
def memstore_ics_latest(stix_file_ics_latest):
logger.debug("Loading STIX memstore for ICS ATT&CK")
stix_file = f"{attack_stix_dir}/v{LATEST_VERSION}/ics-attack.json"
mem_store = MemoryStore()
mem_store.load_from_file(stix_file)
mem_store.load_from_file(stix_file_ics_latest)
return mem_store


@pytest.fixture(scope="session")
def mitre_attack_data_enterprise(memstore_enterprise_latest):
logger.debug("Loading STIX memstore for Enterprise ATT&CK")
mitre_attack_data = MitreAttackData(src=memstore_enterprise_latest)
return mitre_attack_data


@pytest.fixture(scope="session")
def mitre_attack_data_mobile(memstore_mobile_latest):
logger.debug("Loading STIX memstore for Mobile ATT&CK")
mitre_attack_data = MitreAttackData(src=memstore_mobile_latest)
return mitre_attack_data


@pytest.fixture(scope="session")
def mitre_attack_data_ics(memstore_ics_latest):
logger.debug("Loading STIX memstore for ICS ATT&CK")
mitre_attack_data = MitreAttackData(src=memstore_ics_latest)
return mitre_attack_data


@pytest.fixture()
def layer_v3_all():
layer = Layer()
Expand Down
154 changes: 64 additions & 90 deletions tests/test_mass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from pathlib import Path

from loguru import logger
Expand All @@ -6,133 +7,106 @@
from mitreattack.navlayers import Layer, SVGConfig, ToExcel, ToSvg


def check_svg_generation(layer: Layer, path: Path, resource: MemoryStore, index: int, config: SVGConfig = None):
def check_svg_generation(layer: Layer, path: Path, resource: MemoryStore, config: SVGConfig = None):
t = ToSvg(domain=layer.layer.domain, source="memorystore", resource=resource, config=config)
svg_output = path / f"{index}.svg"
svg_output = path / f"{uuid.uuid4()}.svg"
t.to_svg(layerInit=layer, filepath=str(svg_output))
assert svg_output.exists()


def check_xlsx_generation(layer: Layer, path: Path, resource: MemoryStore, index: int):
def check_xlsx_generation(layer: Layer, path: Path, resource: MemoryStore):
e = ToExcel(domain=layer.layer.domain, source="memorystore", resource=resource)
xlsx_output = path / f"{index}.xlsx"
xlsx_output = path / f"{uuid.uuid4()}.xlsx"
e.to_xlsx(layerInit=layer, filepath=str(xlsx_output))
assert xlsx_output.exists()


def test_showSubtechniques(tmp_path: Path, layer_v3_all: Layer, memstore_enterprise_latest: MemoryStore):
"""Test SVG export: Displaying Subtechniques"""
logger.debug(f"{tmp_path=}")
index = 0
for showSubtechniques in ["all", "expanded", "none"]:
for showHeader in [True, False]:
c = SVGConfig(showSubtechniques=showSubtechniques, showHeader=showHeader)
layer_v3_all.layer.description = f"subs={showSubtechniques},showHeader={showHeader}"
showSubtechniques = "all"
showHeader = True

check_svg_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index, config=c
)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index)
index += 1
c = SVGConfig(showSubtechniques=showSubtechniques, showHeader=showHeader)
layer_v3_all.layer.description = f"subs={showSubtechniques},showHeader={showHeader}"

check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, config=c)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)


def test_dimensions(tmp_path: Path, layer_v3_all: Layer, memstore_enterprise_latest: MemoryStore):
"""Test SVG export: dimensions"""
logger.debug(f"{tmp_path=}")
index = 0
for width in [8.5, 11]:
for height in [8.5, 11]:
for headerHeight in [1, 2]:
for unit in ["in", "cm"]:
c = SVGConfig(width=width, height=height, headerHeight=headerHeight, unit=unit)
layer_v3_all.layer.description = f"{width}x{height}{unit}; header={headerHeight}"

check_svg_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index, config=c
)
check_xlsx_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index
)
index += 1
width = 8.5
height = 11
headerHeight = 2
unit = "in"

c = SVGConfig(width=width, height=height, headerHeight=headerHeight, unit=unit)
layer_v3_all.layer.description = f"{width}x{height}{unit}; header={headerHeight}"

check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, config=c)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)


def test_legendWidth(tmp_path: Path, layer_v3_all: Layer, memstore_enterprise_latest: MemoryStore):
"""Test SVG export: legend width variations"""
logger.debug(f"{tmp_path=}")
index = 0
for legendWidth in [3, 6]:
for legendHeight in [1, 2]:
for legendX in [2, 4]:
for legendY in [2, 4]:
c = SVGConfig(
legendDocked=False,
legendWidth=legendWidth,
legendHeight=legendHeight,
legendX=legendX,
legendY=legendY,
)
layer_v3_all.layer.description = (
f"undocked legend, {legendWidth}x{legendHeight} at {legendX}x{legendY}"
)

check_svg_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index, config=c
)
check_xlsx_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index
)
index += 1
legendWidth = 3
legendHeight = 1
legendX = 2
legendY = 4

c = SVGConfig(
legendDocked=False,
legendWidth=3,
legendHeight=1,
legendX=2,
legendY=2,
)
layer_v3_all.layer.description = f"undocked legend, {legendWidth}x{legendHeight} at {legendX}x{legendY}"

check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, config=c)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)


def test_showFilters(tmp_path: Path, layer_v3_all: Layer, memstore_enterprise_latest: MemoryStore):
"""Test SVG export: customization options"""
logger.debug(f"{tmp_path=}")
index = 0
for showFilters in [True, False]:
for showAbout in [True, False]:
for showLegend in [True, False]:
for showDomain in [True, False]:
c = SVGConfig(
showFilters=showFilters, showAbout=showAbout, showLegend=showLegend, showDomain=showDomain
)
layer_v3_all.layer.description = f"legend={showLegend}, filters={showFilters}, about={showAbout}"

check_svg_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index, config=c
)
check_xlsx_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index
)
index += 1
showFilters = True
showAbout = True
showLegend = True
showDomain = True

c = SVGConfig(showFilters=showFilters, showAbout=showAbout, showLegend=showLegend, showDomain=showDomain)
layer_v3_all.layer.description = f"legend={showLegend}, filters={showFilters}, about={showAbout}"

check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, config=c)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)


def test_borders(tmp_path: Path, layer_v3_all: Layer, memstore_enterprise_latest: MemoryStore):
"""Test SVG export: borders"""
logger.debug(f"{tmp_path=}")
index = 0
for border in [0.1, 0.3]:
for tableBorderColor in ["#ddd", "#ffaaaa"]:
c = SVGConfig(border=border, tableBorderColor=tableBorderColor)
layer_v3_all.layer.description = f"border={border}, tableBorderColor={tableBorderColor}"
border = 0.2
tableBorderColor = "#ddd"
c = SVGConfig(border=border, tableBorderColor=tableBorderColor)
layer_v3_all.layer.description = f"border={border}, tableBorderColor={tableBorderColor}"

check_svg_generation(
layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index, config=c
)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index)
index += 1
check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, config=c)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)


def test_counts(tmp_path: Path, layer_v3_all: Layer, memstore_enterprise_latest: MemoryStore):
"""Test SVG export: scores/aggregation"""
logger.debug(f"{tmp_path=}")
index = 0
for countUnscored in [True, False]:
for aggregateFunction in ["average", "min", "max", "sum"]:
layer_v3_all.layer.layout.countUnscored = countUnscored
layer_v3_all.layer.layout.aggregateFunction = aggregateFunction
layer_v3_all.layer.description = f"countUnscored={countUnscored}, aggregateFunction={aggregateFunction}"
print(layer_v3_all.layer.description)

check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest, index=index)
index += 1
countUnscored = True
aggregateFunction = "average"

layer_v3_all.layer.layout.countUnscored = countUnscored
layer_v3_all.layer.layout.aggregateFunction = aggregateFunction
layer_v3_all.layer.description = f"countUnscored={countUnscored}, aggregateFunction={aggregateFunction}"
logger.info(layer_v3_all.layer.description)

check_svg_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)
check_xlsx_generation(layer=layer_v3_all, path=tmp_path, resource=memstore_enterprise_latest)
Loading

0 comments on commit 18fd917

Please sign in to comment.