Skip to content

Commit

Permalink
fix: write one statement per CSV row
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Davie <mldavie@amazon.com>
  • Loading branch information
michaeldavie-amzn committed Sep 24, 2024
1 parent 25dbc7a commit 15a3947
Showing 1 changed file with 26 additions and 47 deletions.
73 changes: 26 additions & 47 deletions trestle/tasks/oscal_catalog_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,11 @@ def convert_control_id(control_id: str) -> str:

def convert_smt_id(smt_id: str) -> str:
"""Convert smt id."""
parts = smt_id.split('_smt')
seg1 = convert_control_id(parts[0])
seg2 = ''
if len(parts) == 2:
seg2 = parts[1]
if '.' in seg2:
seg2 = seg2.replace('.', '(')
seg2 = seg2 + ')'
rval = f'{seg1}{seg2}'
return rval
parts = smt_id.split('_')
control_id = convert_control_id(parts[0])
sub_ids = parts[1].split('.')[1:]
sub_id = ''.join(f'({s})' for s in sub_ids)
return f'{control_id}{sub_id}'


class CsvHelper:
Expand Down Expand Up @@ -118,6 +113,15 @@ def _init_control_parent_map(self, recurse=True) -> None:
raise RuntimeError('{parent} duplicate?')
self._control_parent_map[parent] = control

def derive_text(self, control: Control, part: Part) -> Optional[str]:
"""Derive control text."""
rval = None
if part.prose:
id_ = self._derive_id(part.id)
text = self._resolve_parms(control, part.prose)
rval = join_str(id_, text)
return rval

def get_parent_control(self, ctl_id: str) -> Control:
"""Return parent Control of child Control.id, if any."""
return self._control_parent_map.get(ctl_id)
Expand All @@ -141,16 +145,6 @@ def get_statement_text_for_control(self, control: Control) -> Optional[str]:
statement_text = self._withdrawn(control)
return statement_text

def get_statement_text_for_part(self, control: Control, part: Part) -> Optional[str]:
"""Get statement text for part."""
statement_text = self._derive_text(control, part)
if part.parts:
for subpart in part.parts:
if '_smt' in subpart.id:
partial_text = self._derive_text(control, subpart)
statement_text = join_str(statement_text, partial_text)
return statement_text

def _withdrawn(self, control: Control) -> Optional[str]:
"""Check if withdrawn."""
rval = None
Expand Down Expand Up @@ -192,15 +186,6 @@ def _href_to_control(self, href: str) -> str:
rval = href.replace('#', '').upper()
return rval

def _derive_text(self, control: Control, part: Part) -> Optional[str]:
"""Derive control text."""
rval = None
if part.prose:
id_ = self._derive_id(part.id)
text = self._resolve_parms(control, part.prose)
rval = join_str(id_, text)
return rval

def _derive_id(self, id_: str) -> str:
"""Derive control text sub-part id."""
rval = None
Expand Down Expand Up @@ -319,30 +304,24 @@ def _get_content_by_statement(self) -> List:
self.add(row)
return self.rows

def _add_subparts_by_statement(self, control: Control, part: Part) -> None:
"""Add subparts by statement."""
def _add_statements_recursively(self, control: Control, part: Part) -> None:
"""Add parts and subparts recursively."""
catalog_helper = self.catalog_helper
control_id = convert_control_id(control.id)
for subpart in part.parts:
if '_smt' in subpart.id:
statement_text = catalog_helper.get_statement_text_for_part(control, subpart)
row = [control_id, control.title, convert_smt_id(subpart.id), statement_text]
self.add(row)

if part.id and '_smt' in part.id:
statement_text = catalog_helper.derive_text(control, part)
row = [control_id, control.title, convert_smt_id(part.id), statement_text]
self.add(row)

if part.parts:
for subpart in part.parts:
self._add_statements_recursively(control, subpart)

def _add_parts_by_statement(self, control: Control) -> None:
"""Add parts by statement."""
catalog_helper = self.catalog_helper
control_id = convert_control_id(control.id)
for part in control.parts:
if part.id:
if '_smt' not in part.id:
continue
if part.parts:
self._add_subparts_by_statement(control, part)
else:
statement_text = catalog_helper.get_statement_text_for_part(control, part)
row = [control_id, control.title, convert_smt_id(part.id), statement_text]
self.add(row)
self._add_statements_recursively(control, part)

def _get_content_by_control(self) -> List:
"""Get content by statement."""
Expand Down

0 comments on commit 15a3947

Please sign in to comment.