Skip to content

Commit

Permalink
refactor(source): extract avro inner schema precisely (#19701)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangjinwu authored Dec 9, 2024
1 parent fcac311 commit f3ed1de
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 99 deletions.
6 changes: 5 additions & 1 deletion e2e_test/source_inline/kafka/avro/ref.slt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ select
(bar).b.y
from s;
----
3 4 5 6 6 5 4 3
3 4 5 6 NULL NULL NULL NULL

# Parsing of column `bar` fails even with ints because now `schema` is required.
# This will be fully supported in the next PR
# 3 4 5 6 6 5 4 3


statement ok
Expand Down
128 changes: 47 additions & 81 deletions src/connector/codec/src/decoder/avro/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
// limitations under the License.

mod schema;
use std::sync::LazyLock;

use apache_avro::schema::{DecimalSchema, RecordSchema, UnionSchema};
use apache_avro::schema::{DecimalSchema, UnionSchema};
use apache_avro::types::{Value, ValueKind};
use apache_avro::{Decimal as AvroDecimal, Schema};
use chrono::Datelike;
use itertools::Itertools;
use num_bigint::{BigInt, Sign};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::bail;
use risingwave_common::log::LogSuppresser;
use risingwave_common::types::{
DataType, Date, DatumCow, Interval, JsonbVal, MapValue, ScalarImpl, Time, Timestamp,
Timestamptz, ToOwnedDatum,
Expand All @@ -42,37 +39,22 @@ pub struct AvroParseOptions<'a> {
///
/// FIXME: In theory we should use resolved schema.
/// e.g., it's possible that a field is a reference to a decimal or a record containing a decimal field.
pub schema: Option<&'a Schema>,
schema: &'a Schema,
/// Strict Mode
/// If strict mode is disabled, an int64 can be parsed from an `AvroInt` (int32) value.
pub relax_numeric: bool,
relax_numeric: bool,
}

impl<'a> AvroParseOptions<'a> {
/// Builds parse options bound to `schema`, with relaxed numeric conversion
/// turned on (`relax_numeric: true`).
// NOTE(review): this span is taken from a diff view and shows both the old
// (`Some(schema)`) and the new (`schema`) initializer for the `schema` field.
pub fn create(schema: &'a Schema) -> Self {
Self {
schema: Some(schema),
schema,
relax_numeric: true,
}
}
}

impl<'a> AvroParseOptions<'a> {
/// Returns the schema nested under the current one, as resolved by
/// `avro_extract_field_schema` for the given `key`.
///
/// Yields `None` when no schema is attached to these options, or when the
/// lookup fails; a failed lookup is reported through a rate-limited
/// `tracing::error!`.
fn extract_inner_schema(&self, key: Option<&str>) -> Option<&'a Schema> {
    let outer = self.schema?;
    match avro_extract_field_schema(outer, key) {
        Ok(inner) => Some(inner),
        Err(_err) => {
            // One suppresser shared by every call, so a stream of bad
            // messages does not flood the log.
            static LOG_SUPPRESSOR: LazyLock<LogSuppresser> =
                LazyLock::new(LogSuppresser::default);
            if let Ok(suppressed_count) = LOG_SUPPRESSOR.check() {
                tracing::error!(suppressed_count, "extract sub-schema");
            }
            None
        }
    }
}

/// Parse an avro value into expected type.
///
/// 3 kinds of type info are used for parsing:
Expand All @@ -86,7 +68,7 @@ impl<'a> AvroParseOptions<'a> {
/// `type_expected`, converting the value if possible.
/// - If only value is provided (without schema and `type_expected`),
/// the `DataType` will be inferred.
pub fn convert_to_datum<'b>(
fn convert_to_datum<'b>(
&self,
value: &'b Value,
type_expected: &DataType,
Expand All @@ -110,15 +92,15 @@ impl<'a> AvroParseOptions<'a> {
(_, Value::Null) => return Ok(DatumCow::NULL),
// ---- Union (with >=2 non null variants), and nullable Union ([null, record]) -----
(DataType::Struct(struct_type_info), Value::Union(variant, v)) => {
let Some(Schema::Union(u)) = self.schema else {
let Schema::Union(u) = self.schema else {
// XXX: Is this branch actually unreachable? (if self.schema is correctly used)
return Err(create_error());
};

if let Some(inner) = get_nullable_union_inner(u) {
// nullable Union ([null, record])
return Self {
schema: Some(inner),
schema: inner,
relax_numeric: self.relax_numeric,
}
.convert_to_datum(v, type_expected);
Expand All @@ -139,7 +121,7 @@ impl<'a> AvroParseOptions<'a> {
for (field_name, field_type) in struct_type_info.iter() {
if field_name == expected_field_name {
let datum = Self {
schema: Some(variant_schema),
schema: variant_schema,
relax_numeric: self.relax_numeric,
}
.convert_to_datum(v, field_type)?
Expand All @@ -154,7 +136,12 @@ impl<'a> AvroParseOptions<'a> {
}
// nullable Union ([null, T])
(_, Value::Union(_, v)) => {
let schema = self.extract_inner_schema(None);
let Schema::Union(u) = self.schema else {
return Err(create_error());
};
let Some(schema) = get_nullable_union_inner(u) else {
return Err(create_error());
};
return Self {
schema,
relax_numeric: self.relax_numeric,
Expand Down Expand Up @@ -182,9 +169,9 @@ impl<'a> AvroParseOptions<'a> {
// ---- Decimal -----
(DataType::Decimal, Value::Decimal(avro_decimal)) => {
let (precision, scale) = match self.schema {
Some(Schema::Decimal(DecimalSchema {
Schema::Decimal(DecimalSchema {
precision, scale, ..
})) => (*precision, *scale),
}) => (*precision, *scale),
_ => Err(create_error())?,
};
let decimal = avro_decimal_to_rust_decimal(avro_decimal.clone(), precision, scale)
Expand Down Expand Up @@ -266,14 +253,17 @@ impl<'a> AvroParseOptions<'a> {
ScalarImpl::Interval(Interval::from_month_day_usec(months, days, usecs))
}
// ---- Struct -----
(DataType::Struct(struct_type_info), Value::Record(descs)) => StructValue::new(
(DataType::Struct(struct_type_info), Value::Record(descs)) => StructValue::new({
let Schema::Record(record_schema) = &self.schema else {
return Err(create_error());
};
struct_type_info
.names()
.zip_eq_fast(struct_type_info.types())
.map(|(field_name, field_type)| {
let maybe_value = descs.iter().find(|(k, _v)| k == field_name);
if let Some((_, value)) = maybe_value {
let schema = self.extract_inner_schema(Some(field_name));
if let Some(idx) = record_schema.lookup.get(field_name) {
let value = &descs[*idx].1;
let schema = &record_schema.fields[*idx].schema;
Ok(Self {
schema,
relax_numeric: self.relax_numeric,
Expand All @@ -284,12 +274,15 @@ impl<'a> AvroParseOptions<'a> {
Ok(None)
}
})
.collect::<Result<_, AccessError>>()?,
)
.collect::<Result<_, AccessError>>()?
})
.into(),
// ---- List -----
(DataType::List(item_type), Value::Array(array)) => ListValue::new({
let schema = self.extract_inner_schema(None);
let Schema::Array(element_schema) = &self.schema else {
return Err(create_error());
};
let schema = element_schema;
let mut builder = item_type.create_array_builder(array.len());
for v in array {
let value = Self {
Expand All @@ -316,7 +309,10 @@ impl<'a> AvroParseOptions<'a> {
uuid.as_hyphenated().to_string().into_boxed_str().into()
}
(DataType::Map(map_type), Value::Map(map)) => {
let schema = self.extract_inner_schema(None);
let Schema::Map(value_schema) = &self.schema else {
return Err(create_error());
};
let schema = value_schema;
let mut builder = map_type
.clone()
.into_struct()
Expand Down Expand Up @@ -405,13 +401,22 @@ impl Access for AvroAccess<'_> {
// },
// ...]
value = v;
options.schema = options.extract_inner_schema(None);
let Schema::Union(u) = options.schema else {
return Err(create_error());
};
let Some(schema) = get_nullable_union_inner(u) else {
return Err(create_error());
};
options.schema = schema;
continue;
}
Value::Record(fields) => {
if let Some((_, v)) = fields.iter().find(|(k, _)| k == key) {
value = v;
options.schema = options.extract_inner_schema(Some(key));
let Schema::Record(record_schema) = &options.schema else {
return Err(create_error());
};
if let Some(idx) = record_schema.lookup.get(key) {
value = &fields[*idx].1;
options.schema = &record_schema.fields[*idx].schema;
i += 1;
continue;
}
Expand Down Expand Up @@ -444,7 +449,7 @@ pub(crate) fn avro_decimal_to_rust_decimal(
}

/// If the union schema is `[null, T]` or `[T, null]`, returns `Some(T)`; otherwise returns `None`.
fn get_nullable_union_inner(union_schema: &UnionSchema) -> Option<&'_ Schema> {
pub fn get_nullable_union_inner(union_schema: &UnionSchema) -> Option<&'_ Schema> {
let variants = union_schema.variants();
// Note: `[null, null]` is invalid, so we don't need to worry about that.
if variants.len() == 2 && variants.contains(&Schema::Null) {
Expand All @@ -458,45 +463,6 @@ fn get_nullable_union_inner(union_schema: &UnionSchema) -> Option<&'_ Schema> {
}
}

/// Unwraps a nullable union to its inner schema.
///
/// Any schema that is not a union is returned unchanged. For a union, the
/// result of `get_nullable_union_inner` is returned; if that yields nothing,
/// an error describing the offending union is returned.
pub fn avro_schema_skip_nullable_union(schema: &Schema) -> anyhow::Result<&Schema> {
    let Schema::Union(union_schema) = schema else {
        // Nothing to skip.
        return Ok(schema);
    };
    get_nullable_union_inner(union_schema).ok_or_else(|| {
        anyhow::format_err!(
            "illegal avro union schema, expected [null, T], got {:?}",
            union_schema
        )
    })
}

/// Extracts the inner field/item schema of a record, array, map or union.
///
/// - Record: `name` is required and must be a field of the record; that
///   field's schema is returned.
/// - Array / Map: the contained schema is returned; `name` is ignored.
/// - Union: delegated to `avro_schema_skip_nullable_union`.
/// - Anything else: an error, since the schema has no inner item.
pub fn avro_extract_field_schema<'a>(
    schema: &'a Schema,
    name: Option<&str>,
) -> anyhow::Result<&'a Schema> {
    match schema {
        Schema::Array(item_schema) => Ok(item_schema),
        Schema::Map(value_schema) => Ok(value_schema),
        // Only nullable union should be handled here.
        // We will not extract inner schema for real union (and it's not extractable).
        Schema::Union(_) => avro_schema_skip_nullable_union(schema),
        Schema::Record(RecordSchema { fields, lookup, .. }) => {
            let Some(field_name) = name else {
                return Err(anyhow::format_err!(
                    "no name provided for a field in record"
                ));
            };
            let Some(position) = lookup.get(field_name) else {
                return Err(anyhow::format_err!(
                    "no field named '{}' in record: {:?}",
                    field_name,
                    schema
                ));
            };
            match fields.get(*position) {
                Some(field) => Ok(&field.schema),
                // `lookup` pointed outside `fields`: the record schema itself is broken.
                None => Err(anyhow::format_err!(
                    "illegal avro record schema {:?}",
                    schema
                )),
            }
        }
        _ => bail!("avro schema does not have inner item, schema: {:?}", schema),
    }
}

/// Returns the Unix epoch date (1970-01-01) as the day count produced by
/// `num_days_from_ce`.
pub(crate) fn unix_epoch_days() -> i32 {
    let epoch = Date::from_ymd_uncheck(1970, 1, 1);
    epoch.0.num_days_from_ce()
}
Expand Down
42 changes: 25 additions & 17 deletions src/connector/src/parser/debezium/avro_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
use std::fmt::Debug;
use std::sync::Arc;

use anyhow::Context;
use apache_avro::types::Value;
use apache_avro::{from_avro_datum, Schema};
use risingwave_common::try_match_expand;
use risingwave_connector_codec::decoder::avro::{
avro_extract_field_schema, avro_schema_skip_nullable_union, avro_schema_to_column_descs,
AvroAccess, AvroParseOptions, ResolvedAvroSchema,
avro_schema_to_column_descs, get_nullable_union_inner, AvroAccess, AvroParseOptions,
ResolvedAvroSchema,
};
use risingwave_pb::plan_common::ColumnDesc;

Expand Down Expand Up @@ -162,20 +163,33 @@ impl DebeziumAvroParserConfig {
// See <https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-events>

avro_schema_to_column_descs(
avro_schema_skip_nullable_union(avro_extract_field_schema(
// FIXME: use resolved schema here.
// Currently it works because "after" refers to a subtree in "before",
// but in theory, inside "before" there could also be a reference.
&self.outer_schema,
Some("before"),
)?)?,
// This assumes no external `Ref`s (e.g. "before" referring to "after" or "source").
// Internal `Ref`s inside the "before" tree are allowed.
extract_debezium_table_schema(&self.outer_schema)?,
// TODO: do we need to support map type here?
None,
)
.map_err(Into::into)
}
}

fn extract_debezium_table_schema(root: &Schema) -> anyhow::Result<&Schema> {
let Schema::Record(root_record) = root else {
anyhow::bail!("Root schema of debezium shall be a record but got: {root:?}");
};
let idx = (root_record.lookup.get("before"))
.context("Root schema of debezium shall contain \"before\" field.")?;
let schema = &root_record.fields[*idx].schema;
// It is wrapped inside a union to allow null, so we look inside.
let Schema::Union(union_schema) = schema else {
return Ok(schema);
};
get_nullable_union_inner(union_schema).context(format!(
"illegal avro union schema, expected [null, T], got {:?}",
union_schema
))
}

#[cfg(test)]
mod tests {
use std::io::Read;
Expand Down Expand Up @@ -264,10 +278,7 @@ mod tests {

let outer_schema = get_outer_schema();
let expected_inner_schema = Schema::parse_str(inner_shema_str).unwrap();
let extracted_inner_schema = avro_schema_skip_nullable_union(
avro_extract_field_schema(&outer_schema, Some("before")).unwrap(),
)
.unwrap();
let extracted_inner_schema = extract_debezium_table_schema(&outer_schema).unwrap();
assert_eq!(&expected_inner_schema, extracted_inner_schema);
}

Expand Down Expand Up @@ -355,10 +366,7 @@ mod tests {
fn test_map_to_columns() {
let outer_schema = get_outer_schema();
let columns = avro_schema_to_column_descs(
avro_schema_skip_nullable_union(
avro_extract_field_schema(&outer_schema, Some("before")).unwrap(),
)
.unwrap(),
extract_debezium_table_schema(&outer_schema).unwrap(),
None,
)
.unwrap()
Expand Down

0 comments on commit f3ed1de

Please sign in to comment.