Skip to content

Commit

Permalink
rfb gold now has swt-type column
Browse files Browse the repository at this point in the history
  • Loading branch information
keighrim committed Jul 29, 2024
1 parent 2c7ca4c commit 50846bc
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions role-filler-binding/process.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
#!/usr/bin/env python3

# ====================================|
# Imports
# ====================================|
import csv
import json
import os
Expand All @@ -11,7 +8,6 @@
from typing import Dict, List, Union, Tuple
from clams_utils.aapb import guidhandler

# ====================================|

def build_csv_string(annos: Dict) -> Tuple[bool, str]:
"""Convert a dictionary to a csv string.
Expand Down Expand Up @@ -40,13 +36,18 @@ def build_csv_string(annos: Dict) -> Tuple[bool, str]:

return is_skipped, out

def process_golds(fname: Union[str, os.PathLike]) -> List[Tuple[str, str, bool, str]]:

def process_golds(fname: Union[str, os.PathLike]) -> List[Tuple[str, str, str, bool, str]]:
"""Process a directory of gold files
### params
+ fname := input raw annotation file name
### returns
a list of tuples, each representing a line of the csv
"""
swt_type_dict = {
'231117': 'credits'
}
swt_type = swt_type_dict[fname.split('-')[0]]
out = []
with open(fname, "r", encoding='utf-8') as f:
fp_golds = json.load(f)
Expand All @@ -55,14 +56,14 @@ def process_golds(fname: Union[str, os.PathLike]) -> List[Tuple[str, str, bool,
csv_line = (
guid,
frame,
build_csv_string(annotations)[0],
build_csv_string(annotations)[1],
swt_type,
*build_csv_string(annotations)
)
out.append(csv_line)
return out


def write_csv(csvs: List[Tuple[str, str, str]], outfname: Union[str, os.PathLike]):
def write_csv(csvs: List[Tuple[str, str, str, bool, str]], outfname: Union[str, os.PathLike]):
"""Write csv lines to file
### params
Expand All @@ -73,7 +74,7 @@ def write_csv(csvs: List[Tuple[str, str, str]], outfname: Union[str, os.PathLike
"""
with open(outfname, "w", encoding='utf-8') as f:
writer = csv.writer(f, lineterminator='\n')
writer.writerow(["GUID", "FRAME", "SKIPPED", "ANNOTATIONS"])
writer.writerow(["GUID", "FRAME", "SWT-TYPE", "SKIPPED", "ANNOTATIONS"])
writer.writerows(csvs)


Expand All @@ -86,9 +87,10 @@ def main():
os.makedirs(out_dir)
for root, _, anns in os.walk(source_gold_anns_dir):
for ann in anns:
write_csv(process_golds(os.path.join(root, ann)),
os.path.join(out_dir, guidhandler.get_aapb_guid_from(ann) + "-gold.csv")
)
write_csv(
process_golds(os.path.join(root, ann)),
os.path.join(out_dir, guidhandler.get_aapb_guid_from(ann) + "-gold.csv")
)

if __name__ == "__main__":
main()

0 comments on commit 50846bc

Please sign in to comment.