shared code logic library
stephanieshong committed Jan 31, 2023
1 parent 9f26d16 commit b516b60
Showing 12 changed files with 1,525 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pipeline_logic/v2/shared-logic/src/setup.cfg
@@ -0,0 +1,5 @@
[pep8]
max-line-length = 120

[flake8]
max-line-length = 120
19 changes: 19 additions & 0 deletions pipeline_logic/v2/shared-logic/src/setup.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python
"""Library setup script."""

import os
from setuptools import find_packages, setup

setup(
    name=os.environ['PKG_NAME'],
    version=os.environ['PKG_VERSION'],

    description='My Python library project',

    author="UNITE",

    packages=find_packages(exclude=['contrib', 'docs', 'test']),

    # Please instead specify your dependencies in conda_recipe/meta.yml
    install_requires=[],
)
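The package name and version come from environment variables rather than being hard-coded; a conda build driven by conda_recipe/meta.yml normally exports PKG_NAME and PKG_VERSION. A minimal local sketch, assuming illustrative values (the name and version below are not taken from this commit):

# Hypothetical local build: supply the variables the build tooling would normally set.
import os
import subprocess

os.environ["PKG_NAME"] = "source_cdm_utils"   # illustrative value
os.environ["PKG_VERSION"] = "0.0.1"           # illustrative value
subprocess.run(["python", "setup.py", "sdist"], check=True)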
Empty file.
46 changes: 46 additions & 0 deletions pipeline_logic/v2/shared-logic/src/source_cdm_utils/clean.py
@@ -0,0 +1,46 @@
from pyspark.sql import functions as F


empty_payer_plan_cols = [
    "payer_source_concept_id",
    "plan_concept_id",
    "plan_source_concept_id",
    "sponsor_concept_id",
    "sponsor_source_concept_id",
    "stop_reason_concept_id",
    "stop_reason_source_concept_id"
]


def conceptualize(domain, df, concept):
    concept = concept.select("concept_id", "concept_name")

    concept_id_columns = [col for col in df.columns if col.endswith("_concept_id")]
    # _concept_id columns should be Integer
    df = df.select(
        *[col for col in df.columns if col not in concept_id_columns] +
        [F.col(col).cast("integer").alias(col) for col in concept_id_columns]
    )

    for col in concept_id_columns:
        new_df = df

        if col in empty_payer_plan_cols:
            # Create an empty *_concept_name column for these cols to prevent an OOM during the join
            # while keeping the schema consistent
            new_df = new_df.withColumn("concept_name", F.lit(None).cast("string"))
        else:
            new_df = new_df.join(concept, [new_df[col] == concept["concept_id"]], "left_outer").drop("concept_id")

        concept_type = col[:col.index("_concept_id")]
        new_df = new_df.withColumnRenamed("concept_name", concept_type + "_concept_name")

        df = new_df

    # occurs in observation, measurement
    if "value_as_number" in df.columns:
        df = df.withColumn("value_as_number", df.value_as_number.cast("double"))

    if "person_id" in df.columns:
        df = df.filter(df.person_id.isNotNull())

    return df
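A minimal usage sketch for conceptualize, with the function above in scope and tiny made-up DataFrames (the concept ID and name below are illustrative only):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy domain rows; the *_concept_id value is a string to show the integer cast.
domain_df = spark.createDataFrame(
    [(1, "201826"), (None, "201826")],
    ["person_id", "condition_concept_id"],
)

# Toy concept table with the two columns conceptualize() selects.
concept_df = spark.createDataFrame(
    [(201826, "Type 2 diabetes mellitus")],
    ["concept_id", "concept_name"],
)

result = conceptualize("condition_occurrence", domain_df, concept_df)
# Casts condition_concept_id to integer, joins in condition_concept_name,
# and drops the row with a null person_id.
result.show()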
@@ -0,0 +1,46 @@
from pyspark.sql import functions as F, types as T
from pyspark.sql.window import Window as W


def new_duplicate_rows_with_collision_bits(omop_domain, lookup_df, ctx, pk_col, full_hash_col):

    # Extract all duplicate rows from the domain table
    # Keep two columns: the 51-bit hash (which caused the collision) and the full hash (to differentiate collisions)
    w = W.partitionBy(pk_col)
    duplicates_df = omop_domain.dataframe().select('*', F.count(pk_col).over(w).alias('dupeCount'))\
        .where('dupeCount > 1')\
        .drop('dupeCount')
    duplicates_df = duplicates_df.select(pk_col, full_hash_col)

    if ctx.is_incremental:
        # Count how many rows in the lookup table exist for the collided hash value
        cache = lookup_df.dataframe('previous', schema=T.StructType([
            T.StructField(pk_col, T.LongType(), True),
            T.StructField(full_hash_col, T.StringType(), True),
            T.StructField("collision_bits", T.IntegerType(), True)
        ]))
        cache_count = cache.groupby(pk_col).count()

        # Keep only the rows in duplicates_df that are not currently in the lookup table
        cond = [pk_col, full_hash_col]
        duplicates_df = duplicates_df.join(cache, cond, 'left_anti')

    # Create a counter for rows in duplicates_df
    # Subtract 1 because the default collision resolution bit value is 0
    w2 = W.partitionBy(pk_col).orderBy(pk_col)
    duplicates_df = duplicates_df.withColumn('row_num', F.row_number().over(w2))
    duplicates_df = duplicates_df.withColumn('row_num', (F.col('row_num') - 1))

    # If there are already entries in the lookup table for the given primary key,
    # then add the number of existing entries to the row number counter
    if ctx.is_incremental:
        duplicates_df = duplicates_df.join(cache_count, pk_col, 'left')
        duplicates_df = duplicates_df.fillna(0, subset=['count'])
        duplicates_df = duplicates_df.withColumn('row_num', (F.col('row_num') + F.col('count').cast(T.IntegerType())))

    duplicates_df = duplicates_df.withColumnRenamed('row_num', 'collision_bits')

    # Drop the 'count' column added for incremental transforms; keep only the output columns
    duplicates_df = duplicates_df.select(pk_col, full_hash_col, 'collision_bits')

    return duplicates_df
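The function above expects Foundry-style transform inputs (objects exposing .dataframe()) and a context with is_incremental. A minimal non-incremental sketch, using illustrative stand-in wrappers (not the real transforms API) and made-up hash values:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


class _StubInput:
    """Illustrative stand-in for a transform input exposing .dataframe()."""
    def __init__(self, df):
        self._df = df

    def dataframe(self, *args, **kwargs):
        return self._df


class _StubCtx:
    """Illustrative stand-in for a transform context."""
    is_incremental = False


# Two rows share the same 51-bit hash but have different full hashes (values are made up).
domain_df = spark.createDataFrame(
    [(11, "full-hash-a"), (11, "full-hash-b"), (22, "full-hash-c")],
    ["person_id_51_bit", "full_hash"],
)

dupes = new_duplicate_rows_with_collision_bits(
    _StubInput(domain_df), _StubInput(None), _StubCtx(), "person_id_51_bit", "full_hash"
)
# Each colliding row gets a distinct collision_bits value (0, 1, ...); the lookup table is unused here.
dupes.show()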
36 changes: 36 additions & 0 deletions pipeline_logic/v2/shared-logic/src/source_cdm_utils/manifest.py
@@ -0,0 +1,36 @@
from datetime import datetime
from pyspark.sql import functions as F
from source_cdm_utils import schema


def manifest(ctx, manifest_schema, manifest_df, site_id_df, data_partner_ids, omop_vocab, control_map):
    # Handle an empty manifest
    if manifest_df.count() == 0:
        schema_struct = schema.schema_dict_to_struct(manifest_schema, False)
        data = [["[No manifest provided]" for _ in manifest_schema]]
        processed_df = ctx.spark_session.createDataFrame(data, schema_struct)

        # Add the CDM name
        site_id_df = site_id_df.join(data_partner_ids, "data_partner_id", "left")
        try:
            cdm = site_id_df.head().source_cdm
            processed_df = processed_df.withColumn("CDM_NAME", F.lit(cdm))
        except IndexError:
            pass
    else:
        processed_df = manifest_df

    curr_date = datetime.date(datetime.now())
    processed_df = processed_df.withColumn("CONTRIBUTION_DATE", F.lit(curr_date).cast("date"))

    omop_vocab = omop_vocab.where(omop_vocab["vocabulary_id"] == "None") \
        .where(omop_vocab["vocabulary_name"] == "OMOP Standardized Vocabularies")
    vocabulary_version = omop_vocab.head().vocabulary_version
    processed_df = processed_df.withColumn("N3C_VOCAB_VERSION", F.lit(vocabulary_version).cast("string"))

    # Compute the approximate expected person count using the CONTROL_MAP file:
    # approximate expected person count = (# rows in CONTROL_MAP) / 2 * 3
    control_map_total_count = control_map.count()
    approx_expected_person_count = int((control_map_total_count / 2) * 3)
    processed_df = processed_df.withColumn("APPROX_EXPECTED_PERSON_COUNT", F.lit(approx_expected_person_count))

    return processed_df
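As a quick check of the APPROX_EXPECTED_PERSON_COUNT heuristic above: a CONTROL_MAP with a made-up count of 10,000 rows yields int(10000 / 2 * 3) = 15,000.

# Worked example of the heuristic (the row count is illustrative).
control_map_total_count = 10_000
approx_expected_person_count = int((control_map_total_count / 2) * 3)
assert approx_expected_person_count == 15_000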
171 changes: 171 additions & 0 deletions pipeline_logic/v2/shared-logic/src/source_cdm_utils/parse.py
@@ -0,0 +1,171 @@
import csv
from pyspark.sql import Row, functions as F, types as T

header_col = "__is_header__"
errorCols = ["row_number", "error_type", "error_details"]
ErrorRow = Row(header_col, *errorCols)


def required_parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols):
    parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols)


def optional_parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols):
    parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols)


def cached_parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols):
    regexPattern = "(?i).*" + domain + "\\.csv"
    fs = payload_input.filesystem()
    files_df = fs.files(regex=regexPattern)

    if files_df.count() > 0:
        parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols)
    else:
        clean_output.abort()
        error_output.abort()


def metadata_parse(payload_input, filename, clean_output, error_output, all_cols, required_cols):
    regex = "(?i).*" + filename + "\\.csv"
    clean_df, error_df = parse_csv(payload_input, regex, all_cols, required_cols)

    clean_output.write_dataframe(clean_df)
    error_output.write_dataframe(error_df)


# API above, functionality below


def parse(filename_input, payload_input, domain, clean_output, error_output, all_cols, required_cols):
    regex = "(?i).*" + domain + "\\.csv"
    clean_df, error_df = parse_csv(payload_input, regex, all_cols, required_cols)

    payload_filename = filename_input.dataframe().where(F.col("newest_payload") == True).head().payload  # noqa
    clean_df = clean_df.withColumn("payload", F.lit(payload_filename))

    clean_output.write_dataframe(clean_df)
    error_output.write_dataframe(error_df)


def parse_csv(payload_input, regex, all_cols, required_cols):
    # Get the correct file from the unzipped payload
    files_df = payload_input.filesystem().files(regex=regex)

    # Parse the CSV into clean rows and error rows
    parser = CsvParser(payload_input, all_cols, required_cols)
    result_rdd = files_df.rdd.flatMap(parser)

    # The idea behind caching here was to make sure Spark didn't parse the CSV twice: once for the clean
    # rows and once for the error rows. However, some CSVs are larger than 100 GB, and this line always
    # caused an OOM for those. After removing it, we could parse those files even with a small profile.
    # result_rdd = result_rdd.cache()

    # Separate into good and bad rows
    clean_df = rddToDf(result_rdd, "clean", T.StructType([T.StructField(col, T.StringType()) for col in all_cols]))
    error_df = rddToDf(result_rdd, "error", T.StructType([T.StructField(col, T.StringType()) for col in errorCols]))

    return (clean_df, error_df)


def rddToDf(inputRdd, rowType, schema):
    # Filter by type and get the rows
    resultRdd = inputRdd.filter(lambda row: row[0] == rowType)
    resultRdd = resultRdd.map(lambda row: row[1])

    if resultRdd.isEmpty():
        # Convert to a DF using the given schema (which has no header column)
        resultDf = resultRdd.toDF(schema)
    else:
        # Convert to a DF using the RDD Rows' schema, then drop the header row so only data rows remain.
        # This is needed to ensure we get the right schema.
        resultDf = resultRdd.toDF()
        resultDf = resultDf.filter(resultDf[header_col] == False).drop(header_col)

    return resultDf


class CsvParser():
    def __init__(self, rawInput, all_cols, required_cols):
        self.rawInput = rawInput
        self.all_cols = all_cols
        self.required_cols = required_cols

    def __call__(self, csvFilePath):
        try:
            dialect = self.determineDialect(csvFilePath)
        except Exception as e:
            yield ("error", ErrorRow(False, "0", "Could not determine the CSV dialect", repr(e)))
            return

        with self.rawInput.filesystem().open(csvFilePath.path, errors='ignore') as csvFile:
            csvReader = csv.reader(csvFile, dialect=dialect)
            yield from self.parseHeader(csvReader)
            yield from self.parseFile(csvReader)

    def determineDialect(self, csvFilePath):
        with self.rawInput.filesystem().open(csvFilePath.path, errors='ignore') as csvFile:
            dialect = csv.Sniffer().sniff(csvFile.readline(), delimiters=",|")

        return dialect

    def parseHeader(self, csvReader):
        header = next(csvReader)
        header = [x.strip().strip("\ufeff").upper() for x in header]
        header = [*filter(lambda col: col, header)]  # Remove empty headers

        self.CleanRow = Row(header_col, *header)
        self.expected_num_fields = len(header)

        yield ("clean", self.CleanRow(True, *header))
        yield ("error", ErrorRow(True, "", "", ""))

        warningDetails = {
            "all columns": self.all_cols,
            "required columns": self.required_cols,
            "header": header
        }

        # Emit a warning for every column in the required schema but not in the header
        for col in self.required_cols:
            if col not in header:
                message = f"Header did not contain required column `{col}`"
                yield ("error", ErrorRow(False, "0", message, warningDetails))

        # Emit a warning for every column in the header but not in the schema
        for col in header:
            if col not in self.all_cols:
                message = f"Header contained unexpected extra column `{col}`"
                yield ("error", ErrorRow(False, "0", message, warningDetails))

    def parseFile(self, csvReader):
        i = 0
        while True:
            i += 1

            nextError = False
            try:
                row = next(csvReader)
            except StopIteration:
                break
            except Exception as e:
                nextError = [str(i), "Unparsable row", repr(e)]

            # Hit an error while parsing
            if nextError:
                yield ("error", ErrorRow(False, *nextError))

            # Properly formatted row
            elif len(row) == self.expected_num_fields:
                yield ("clean", self.CleanRow(False, *row))

            # Ignore empty rows/extra newlines
            elif not row:
                continue

            # Improperly formatted row
            else:
                message = f"Incorrect number of fields. Expected {self.expected_num_fields} but found {len(row)}."
                yield ("error", ErrorRow(False, str(i), message, str(row)))