Skip to content

Commit

Permalink
Merge pull request #121 from mitre-attack/fix/120
Browse files Browse the repository at this point in the history
Check if query results exist
  • Loading branch information
jondricek authored Aug 9, 2023
2 parents b1e5252 + 0d2cd4b commit ed3f579
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
12 changes: 12 additions & 0 deletions mitreattack/stix20/MitreAttackData.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(self, stix_filepath: str):
if not isinstance(stix_filepath, str):
raise TypeError(f"Argument stix_filepath must be of type str, not {type(stix_filepath)}")

self.stix_filepath = stix_filepath

self.src = MemoryStore()
self.src.load_from_file(stix_filepath)

Expand Down Expand Up @@ -302,6 +304,9 @@ def get_objects_by_type(self, stix_type: str, remove_revoked_deprecated=False) -
if remove_revoked_deprecated:
objects = self.remove_revoked_deprecated(objects)

if not objects:
return []

# since ATT&CK has custom objects, we need to reconstruct the query results
return [StixObjectFactory(o) for o in objects]

Expand Down Expand Up @@ -520,6 +525,10 @@ def get_object_by_stix_id(self, stix_id: str) -> object:
the STIX Domain Object specified by the STIX ID
"""
object = self.src.get(stix_id)

if not object:
raise ValueError(f"{stix_id} not found in {self.stix_filepath}")

return StixObjectFactory(object)

def get_object_by_attack_id(self, attack_id: str, stix_type: str) -> object:
Expand Down Expand Up @@ -586,6 +595,9 @@ def get_objects_by_name(self, name: str, stix_type: str) -> list:
filter = [Filter("type", "=", stix_type), Filter("name", "=", name)]
objects = self.src.query(filter)

if not objects:
return []

# since ATT&CK has custom objects, we need to reconstruct the query results
return [StixObjectFactory(o) for o in objects]

Expand Down
6 changes: 4 additions & 2 deletions mitreattack/stix20/custom_attack_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def StixObjectFactory(data: dict) -> object:
"x-mitre-data-component": DataComponent,
}

if "type" in data and data["type"] in stix_type_to_custom_class:
return stix_type_to_custom_class[data["type"]](**data, allow_custom=True)
stix_type = data.get("type")

if data and stix_type in stix_type_to_custom_class:
return stix_type_to_custom_class[stix_type](**data, allow_custom=True)
return data


Expand Down

0 comments on commit ed3f579

Please sign in to comment.