diff --git a/Cargo.lock b/Cargo.lock
index 8ffae03a..9157fbe5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -448,6 +448,7 @@ dependencies = [
  "mongodb-support",
  "ndc-models",
  "ndc-query-plan",
+ "ref-cast",
  "schemars",
  "serde",
  "serde_json",
@@ -475,6 +476,15 @@
 version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"

+[[package]]
+name = "convert_case"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca"
+dependencies = [
+ "unicode-segmentation",
+]
+
 [[package]]
 name = "core-foundation"
 version = "0.9.4"
@@ -637,13 +647,42 @@
 version = "0.99.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321"
 dependencies = [
- "convert_case",
+ "convert_case 0.4.0",
  "proc-macro2",
  "quote",
  "rustc_version 0.4.0",
  "syn 1.0.109",
 ]

+[[package]]
+name = "deriving-via-impl"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0ed8bf3147663d533313857a62e60f1b23f680992b79defe99211fc65afadcb4"
+dependencies = [
+ "convert_case 0.6.0",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.66",
+]
+
+[[package]]
+name = "deriving_via"
+version = "1.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99061ea972ed08b607ac4769035e05c0c48a78a23e7088220dd1c336e026d1e9"
+dependencies = [
+ "deriving-via-impl",
+ "itertools",
+ "proc-macro-error",
+ "proc-macro2",
+ "quote",
+ "strum",
+ "strum_macros",
+ "syn 2.0.66",
+ "typed-builder 0.18.2",
+]
+
 [[package]]
 name = "diff"
 version = "0.1.13"
@@ -1625,6 +1664,12 @@
 dependencies = [
  "unicase",
 ]

+[[package]]
+name = "minimal-lexical"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
+
 [[package]]
 name = "miniz_oxide"
 version = "0.7.3"
@@ -1766,6 +1811,7 @@ dependencies = [
  "anyhow",
  "clap",
  "configuration",
+ "deriving_via",
  "futures-util",
  "indexmap 2.2.6",
  "itertools",
@@ -1773,6 +1819,8 @@
  "mongodb-agent-common",
  "mongodb-support",
  "ndc-models",
+ "nom",
+ "pretty_assertions",
  "proptest",
  "serde",
  "serde_json",
@@ -1938,6 +1986,16 @@
  "smol_str",
 ]

+[[package]]
+name = "nom"
+version = "7.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
+dependencies = [
+ "memchr",
+ "minimal-lexical",
+]
+
 [[package]]
 name = "nonempty"
 version = "0.10.0"
@@ -2288,6 +2346,30 @@ dependencies = [
  "yansi",
 ]

+[[package]]
+name = "proc-macro-error"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
+dependencies = [
+ "proc-macro-error-attr",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+ "version_check",
+]
+
+[[package]]
+name = "proc-macro-error-attr"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "version_check",
+]
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.85"
@@ -3129,6 +3211,25 @@
 version = "0.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"

+[[package]]
+name = "strum"
+version = "0.26.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
+
+[[package]]
+name = "strum_macros"
+version = "0.26.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
+dependencies = [
+ "heck 0.5.0",
+ "proc-macro2",
+ "quote",
+ "rustversion",
+ "syn 2.0.66",
+]
+
 [[package]]
 name = "subtle"
 version = "2.5.0"
@@ -3735,6 +3836,12 @@
 version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e4259d9d4425d9f0661581b804cb85fe66a4c631cadd8f490d1c13a35d5d9291"

+[[package]]
+name = "unicode-segmentation"
+version = "1.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
+
 [[package]]
 name = "unicode-width"
 version = "0.1.13"
diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml
index 031d7891..40b77c19 100644
--- a/crates/cli/Cargo.toml
+++ b/crates/cli/Cargo.toml
@@ -3,6 +3,9 @@ name = "mongodb-cli-plugin"
 edition = "2021"
 version.workspace = true

+[features]
+native-query-subcommand = []
+
 [dependencies]
 configuration = { path = "../configuration" }
 mongodb-agent-common = { path = "../mongodb-agent-common" }
@@ -11,16 +14,18 @@ mongodb-support = { path = "../mongodb-support" }
 anyhow = "1.0.80"
 clap = { version = "4.5.1", features = ["derive", "env"] }
+deriving_via = "^1.6.1"
 futures-util = "0.3.28"
 indexmap = { workspace = true }
 itertools = { workspace = true }
 ndc-models = { workspace = true }
+nom = "^7.1.3"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = { version = "1.0.113", features = ["raw_value"] }
 thiserror = "1.0.57"
 tokio = { version = "1.36.0", features = ["full"] }

 [dev-dependencies]
-test-helpers = { path = "../test-helpers" }
-
+pretty_assertions = "1"
 proptest = "1"
+test-helpers = { path = "../test-helpers" }
diff --git a/crates/cli/src/exit_codes.rs b/crates/cli/src/exit_codes.rs
new file mode 100644
index 00000000..a0015264
--- /dev/null
+++ b/crates/cli/src/exit_codes.rs
@@ -0,0 +1,18 @@
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum ExitCode {
+    CouldNotReadAggregationPipeline,
+    CouldNotReadConfiguration,
+    ErrorWriting,
+    RefusedToOverwrite,
+}
+
+impl From<ExitCode> for i32 {
+    fn from(value: ExitCode) -> Self {
+        match value {
+            ExitCode::CouldNotReadAggregationPipeline => 201,
+            ExitCode::CouldNotReadConfiguration => 202,
+            ExitCode::ErrorWriting => 204,
+            ExitCode::RefusedToOverwrite => 203,
+        }
+    }
+}
diff --git a/crates/cli/src/introspection/sampling.rs b/crates/cli/src/introspection/sampling.rs
index c01360ca..f027c01b 100644
--- a/crates/cli/src/introspection/sampling.rs
+++ b/crates/cli/src/introspection/sampling.rs
@@ -94,7 +94,7 @@ async fn sample_schema_from_collection(
     }
 }

-fn make_object_type(
+pub fn make_object_type(
     object_type_name: &ndc_models::ObjectTypeName,
     document: &Document,
     is_collection_type: bool,
diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs
index 1baef324..0e4e81a8 100644
--- a/crates/cli/src/lib.rs
+++ b/crates/cli/src/lib.rs
@@ -1,15 +1,21 @@
 //! The interpretation of the commands that the CLI can handle.

+mod exit_codes;
 mod introspection;
 mod logging;

+#[cfg(feature = "native-query-subcommand")]
+mod native_query;
+
 use std::path::PathBuf;

 use clap::{Parser, Subcommand};

 // Exported for use in tests
 pub use introspection::type_from_bson;
-use mongodb_agent_common::state::ConnectorState;
+use mongodb_agent_common::state::try_init_state_from_uri;
+#[cfg(feature = "native-query-subcommand")]
+pub use native_query::native_query_from_pipeline;

 #[derive(Debug, Clone, Parser)]
 pub struct UpdateArgs {
@@ -28,23 +34,32 @@ pub enum Command {
     /// Update the configuration by introspecting the database, using the configuration options.
     Update(UpdateArgs),
+
+    #[cfg(feature = "native-query-subcommand")]
+    #[command(subcommand)]
+    NativeQuery(native_query::Command),
 }

 pub struct Context {
     pub path: PathBuf,
-    pub connector_state: ConnectorState,
+    pub connection_uri: Option<String>,
 }

 /// Run a command in a given directory.
 pub async fn run(command: Command, context: &Context) -> anyhow::Result<()> {
     match command {
         Command::Update(args) => update(context, &args).await?,
+
+        #[cfg(feature = "native-query-subcommand")]
+        Command::NativeQuery(command) => native_query::run(context, command).await?,
     };
     Ok(())
 }

 /// Update the configuration in the current directory by introspecting the database.
 async fn update(context: &Context, args: &UpdateArgs) -> anyhow::Result<()> {
+    let connector_state = try_init_state_from_uri(context.connection_uri.as_ref()).await?;
+
     let configuration_options = configuration::parse_configuration_options_file(&context.path).await;
     // Prefer arguments passed to cli, and fallback to the configuration file
@@ -72,7 +87,7 @@ async fn update(context: &Context, args: &UpdateArgs) -> anyhow::Result<()> {
     if !no_validator_schema {
         let schemas_from_json_validation =
-            introspection::get_metadata_from_validation_schema(&context.connector_state).await?;
+            introspection::get_metadata_from_validation_schema(&connector_state).await?;
         configuration::write_schema_directory(&context.path, schemas_from_json_validation).await?;
     }

@@ -81,7 +96,7 @@
         sample_size,
         all_schema_nullable,
         config_file_changed,
-        &context.connector_state,
+        &connector_state,
         &existing_schemas,
     )
     .await?;
diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs
index 9b1752e4..20b508b9 100644
--- a/crates/cli/src/main.rs
+++ b/crates/cli/src/main.rs
@@ -3,12 +3,11 @@
 //! This is intended to be automatically downloaded and invoked via the Hasura CLI, as a plugin.
 //! It is unlikely that end-users will use it directly.

-use anyhow::anyhow;
 use std::env;
 use std::path::PathBuf;

 use clap::{Parser, ValueHint};
-use mongodb_agent_common::state::{try_init_state_from_uri, DATABASE_URI_ENV_VAR};
+use mongodb_agent_common::state::DATABASE_URI_ENV_VAR;
 use mongodb_cli_plugin::{run, Command, Context};

 /// The command-line arguments.
@@ -17,6 +16,7 @@ pub struct Args {
     /// The path to the configuration. Defaults to the current directory.
     #[arg(
         long = "context-path",
+        short = 'p',
        env = "HASURA_PLUGIN_CONNECTOR_CONTEXT_PATH",
        value_name = "DIRECTORY",
        value_hint = ValueHint::DirPath
@@ -46,16 +46,9 @@ pub async fn main() -> anyhow::Result<()> {
         Some(path) => path,
         None => env::current_dir()?,
     };
-    let connection_uri = args.connection_uri.ok_or(anyhow!(
-        "Missing environment variable {}",
-        DATABASE_URI_ENV_VAR
-    ))?;
-    let connector_state = try_init_state_from_uri(&connection_uri)
-        .await
-        .map_err(|e| anyhow!("Error initializing MongoDB state {}", e))?;
     let context = Context {
         path,
-        connector_state,
+        connection_uri: args.connection_uri,
     };
     run(args.subcommand, &context).await?;
     Ok(())
diff --git a/crates/cli/src/native_query/aggregation_expression.rs b/crates/cli/src/native_query/aggregation_expression.rs
new file mode 100644
index 00000000..16dc65dc
--- /dev/null
+++ b/crates/cli/src/native_query/aggregation_expression.rs
@@ -0,0 +1,131 @@
+use std::collections::BTreeMap;
+use std::iter::once;
+
+use configuration::schema::{ObjectField, ObjectType, Type};
+use itertools::Itertools as _;
+use mongodb::bson::{Bson, Document};
+use mongodb_support::BsonScalarType;
+
+use super::helpers::nested_field_type;
+use super::pipeline_type_context::PipelineTypeContext;
+
+use super::error::{Error, Result};
+use super::reference_shorthand::{parse_reference_shorthand, Reference};
+
+pub fn infer_type_from_aggregation_expression(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    bson: Bson,
+) -> Result<Type> {
+    let t = match bson {
+        Bson::Double(_) => Type::Scalar(BsonScalarType::Double),
+        Bson::String(string) => infer_type_from_reference_shorthand(context, &string)?,
+        Bson::Array(_) => todo!("array type"),
+        Bson::Document(doc) => {
+            infer_type_from_aggregation_expression_document(context, desired_object_type_name, doc)?
+        }
+        Bson::Boolean(_) => todo!(),
+        Bson::Null => todo!(),
+        Bson::RegularExpression(_) => todo!(),
+        Bson::JavaScriptCode(_) => todo!(),
+        Bson::JavaScriptCodeWithScope(_) => todo!(),
+        Bson::Int32(_) => todo!(),
+        Bson::Int64(_) => todo!(),
+        Bson::Timestamp(_) => todo!(),
+        Bson::Binary(_) => todo!(),
+        Bson::ObjectId(_) => todo!(),
+        Bson::DateTime(_) => todo!(),
+        Bson::Symbol(_) => todo!(),
+        Bson::Decimal128(_) => todo!(),
+        Bson::Undefined => todo!(),
+        Bson::MaxKey => todo!(),
+        Bson::MinKey => todo!(),
+        Bson::DbPointer(_) => todo!(),
+    };
+    Ok(t)
+}
+
+fn infer_type_from_aggregation_expression_document(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    mut document: Document,
+) -> Result<Type> {
+    let mut expression_operators = document
+        .keys()
+        .filter(|key| key.starts_with("$"))
+        .collect_vec();
+    let expression_operator = expression_operators.pop().map(ToString::to_string);
+    let is_empty = expression_operators.is_empty();
+    match (expression_operator, is_empty) {
+        (_, false) => Err(Error::MultipleExpressionOperators(document)),
+        (Some(operator), _) => {
+            let operands = document.remove(&operator).unwrap();
+            infer_type_from_operator_expression(
+                context,
+                desired_object_type_name,
+                &operator,
+                operands,
+            )
+        }
+        (None, _) => infer_type_from_document(context, desired_object_type_name, document),
+    }
+}
+
+fn infer_type_from_operator_expression(
+    _context: &mut PipelineTypeContext<'_>,
+    _desired_object_type_name: &str,
+    operator: &str,
+    operands: Bson,
+) -> Result<Type> {
+    let t = match (operator, operands) {
+        ("$split", _) => Type::ArrayOf(Box::new(Type::Scalar(BsonScalarType::String))),
+        (op, _) => Err(Error::UnknownAggregationOperator(op.to_string()))?,
+    };
+    Ok(t)
+}
+
+/// This is a document that is evaluated as a plain value, not as an aggregation expression.
+fn infer_type_from_document(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    document: Document,
+) -> Result<Type> {
+    let object_type_name = context.unique_type_name(desired_object_type_name);
+    let fields = document
+        .into_iter()
+        .map(|(field_name, bson)| {
+            let field_object_type_name = format!("{desired_object_type_name}_{field_name}");
+            let object_field_type =
+                infer_type_from_aggregation_expression(context, &field_object_type_name, bson)?;
+            let object_field = ObjectField {
+                r#type: object_field_type,
+                description: None,
+            };
+            Ok((field_name.into(), object_field))
+        })
+        .collect::<Result<BTreeMap<_, _>>>()?;
+    let object_type = ObjectType {
+        fields,
+        description: None,
+    };
+    context.insert_object_type(object_type_name.clone(), object_type);
+    Ok(Type::Object(object_type_name.into()))
+}
+
+pub fn infer_type_from_reference_shorthand(
+    context: &mut PipelineTypeContext<'_>,
+    input: &str,
+) -> Result<Type> {
+    let reference = parse_reference_shorthand(input)?;
+    let t = match reference {
+        Reference::NativeQueryVariable { .. } => todo!(),
+        Reference::PipelineVariable { .. } => todo!(),
+        Reference::InputDocumentField { name, nested_path } => {
+            let doc_type = context.get_input_document_type_name()?;
+            let path = once(&name).chain(&nested_path);
+            nested_field_type(context, doc_type.to_string(), path)?
+        }
+        Reference::String => Type::Scalar(BsonScalarType::String),
+    };
+    Ok(t)
+}
diff --git a/crates/cli/src/native_query/error.rs b/crates/cli/src/native_query/error.rs
new file mode 100644
index 00000000..11be9841
--- /dev/null
+++ b/crates/cli/src/native_query/error.rs
@@ -0,0 +1,68 @@
+use configuration::schema::Type;
+use mongodb::bson::{self, Bson, Document};
+use ndc_models::{FieldName, ObjectTypeName};
+use thiserror::Error;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Clone, Debug, Error)]
+pub enum Error {
+    #[error("Cannot infer a result type for an empty pipeline")]
+    EmptyPipeline,
+
+    #[error(
+        "Expected {reference} to reference an array, but instead it references a {referenced_type:?}"
+    )]
+    ExpectedArrayReference {
+        reference: Bson,
+        referenced_type: Type,
+    },
+
+    #[error("Expected an object type, but got: {actual_type:?}")]
+    ExpectedObject { actual_type: Type },
+
+    #[error("Expected a path for the $unwind stage")]
+    ExpectedStringPath(Bson),
+
+    #[error(
+        "Cannot infer a result document type for pipeline because it does not produce documents"
+    )]
+    IncompletePipeline,
+
+    #[error("An object representing an expression must have exactly one field: {0}")]
+    MultipleExpressionOperators(Document),
+
+    #[error("Object type, {object_type}, does not have a field named {field_name}")]
+    ObjectMissingField {
+        object_type: ObjectTypeName,
+        field_name: FieldName,
+    },
+
+    #[error("Type mismatch in {context}: expected {expected:?}, but got {actual:?}")]
+    TypeMismatch {
+        context: String,
+        expected: String,
+        actual: Bson,
+    },
+
+    #[error("Cannot infer a result type for this pipeline. But you can create a native query by writing the configuration file by hand.")]
+    UnableToInferResultType,
+
+    #[error("Error parsing a string in the aggregation pipeline: {0}")]
+    UnableToParseReferenceShorthand(String),
+
+    #[error("Unknown aggregation operator: {0}")]
+    UnknownAggregationOperator(String),
+
+    #[error("Type inference is not currently implemented for stage {stage_index} in the aggregation pipeline. Please file a bug report, and declare types for your native query by hand.\n\n{stage}")]
+    UnknownAggregationStage {
+        stage_index: usize,
+        stage: bson::Document,
+    },
+
+    #[error("Native query input collection, \"{0}\", is not defined in the connector schema")]
+    UnknownCollection(String),
+
+    #[error("Unknown object type, \"{0}\"")]
+    UnknownObjectType(String),
+}
diff --git a/crates/cli/src/native_query/helpers.rs b/crates/cli/src/native_query/helpers.rs
new file mode 100644
index 00000000..052c4297
--- /dev/null
+++ b/crates/cli/src/native_query/helpers.rs
@@ -0,0 +1,54 @@
+use configuration::{schema::Type, Configuration};
+use ndc_models::{CollectionInfo, CollectionName, FieldName, ObjectTypeName};
+
+use super::{
+    error::{Error, Result},
+    pipeline_type_context::PipelineTypeContext,
+};
+
+fn find_collection<'a>(
+    configuration: &'a Configuration,
+    collection_name: &CollectionName,
+) -> Result<&'a CollectionInfo> {
+    if let Some(collection) = configuration.collections.get(collection_name) {
+        return Ok(collection);
+    }
+    if let Some((_, function)) = configuration.functions.get(collection_name) {
+        return Ok(function);
+    }
+
+    Err(Error::UnknownCollection(collection_name.to_string()))
+}
+
+pub fn find_collection_object_type(
+    configuration: &Configuration,
+    collection_name: &CollectionName,
+) -> Result<ObjectTypeName> {
+    let collection = find_collection(configuration, collection_name)?;
+    Ok(collection.collection_type.clone())
+}
+
+/// Looks up the given object type, and traverses the given field path to get the type of the
+/// referenced field. If `nested_path` is empty returns the type of the original object.
+pub fn nested_field_type<'a>(
+    context: &PipelineTypeContext<'_>,
+    object_type_name: String,
+    nested_path: impl IntoIterator<Item = &'a FieldName>,
+) -> Result<Type> {
+    let mut parent_type = Type::Object(object_type_name);
+    for path_component in nested_path {
+        if let Type::Object(type_name) = parent_type {
+            let object_type = context
+                .get_object_type(&type_name.clone().into())
+                .ok_or_else(|| Error::UnknownObjectType(type_name.clone()))?;
+            let field = object_type.fields.get(path_component).ok_or_else(|| {
+                Error::ObjectMissingField {
+                    object_type: type_name.into(),
+                    field_name: path_component.clone(),
+                }
+            })?;
+            parent_type = field.r#type.clone();
+        }
+    }
+    Ok(parent_type)
+}
diff --git a/crates/cli/src/native_query/infer_result_type.rs b/crates/cli/src/native_query/infer_result_type.rs
new file mode 100644
index 00000000..eb5c8b02
--- /dev/null
+++ b/crates/cli/src/native_query/infer_result_type.rs
@@ -0,0 +1,475 @@
+use std::{collections::BTreeMap, iter::once};
+
+use configuration::{
+    schema::{ObjectField, ObjectType, Type},
+    Configuration,
+};
+use mongodb::bson::{Bson, Document};
+use mongodb_support::{
+    aggregate::{Accumulator, Pipeline, Stage},
+    BsonScalarType,
+};
+use ndc_models::{CollectionName, FieldName, ObjectTypeName};
+
+use crate::introspection::{sampling::make_object_type, type_unification::unify_object_types};
+
+use super::{
+    aggregation_expression::{
+        self, infer_type_from_aggregation_expression, infer_type_from_reference_shorthand,
+    },
+    error::{Error, Result},
+    helpers::find_collection_object_type,
+    pipeline_type_context::{PipelineTypeContext, PipelineTypes},
+    reference_shorthand::{parse_reference_shorthand, Reference},
+};
+
+type ObjectTypes = BTreeMap<ObjectTypeName, ObjectType>;
+
+pub fn infer_result_type(
+    configuration: &Configuration,
+    // If we have to define a new object type, use this name
+    desired_object_type_name: &str,
+    input_collection: Option<&CollectionName>,
+    pipeline: &Pipeline,
+) -> Result<PipelineTypes> {
+    let collection_doc_type = input_collection
+        .map(|collection_name| find_collection_object_type(configuration, collection_name))
+        .transpose()?;
+    let mut stages = pipeline.iter().enumerate();
+    let mut context = PipelineTypeContext::new(configuration, collection_doc_type);
+    match stages.next() {
+        Some((stage_index, stage)) => infer_result_type_helper(
+            &mut context,
+            desired_object_type_name,
+            stage_index,
+            stage,
+            stages,
+        ),
+        None => Err(Error::EmptyPipeline),
+    }?;
+    context.try_into()
+}
+
+pub fn infer_result_type_helper<'a, 'b>(
+    context: &mut PipelineTypeContext<'a>,
+    desired_object_type_name: &str,
+    stage_index: usize,
+    stage: &Stage,
+    mut rest: impl Iterator<Item = (usize, &'b Stage)>,
+) -> Result<()> {
+    match stage {
+        Stage::Documents(docs) => {
+            let document_type_name =
+                context.unique_type_name(&format!("{desired_object_type_name}_documents"));
+            let new_object_types = infer_type_from_documents(&document_type_name, docs);
+            context.set_stage_doc_type(document_type_name, new_object_types);
+        }
+        Stage::Match(_) => (),
+        Stage::Sort(_) => (),
+        Stage::Limit(_) => (),
+        Stage::Lookup { .. } => todo!("lookup stage"),
+        Stage::Skip(_) => (),
+        Stage::Group {
+            key_expression,
+            accumulators,
+        } => {
+            let object_type_name = infer_type_from_group_stage(
+                context,
+                desired_object_type_name,
+                key_expression,
+                accumulators,
+            )?;
+            context.set_stage_doc_type(object_type_name, Default::default())
+        }
+        Stage::Facet(_) => todo!("facet stage"),
+        Stage::Count(_) => todo!("count stage"),
+        Stage::ReplaceWith(selection) => {
+            let selection: &Document = selection.into();
+            let result_type = aggregation_expression::infer_type_from_aggregation_expression(
+                context,
+                desired_object_type_name,
+                selection.clone().into(),
+            )?;
+            match result_type {
+                Type::Object(object_type_name) => {
+                    context.set_stage_doc_type(object_type_name.into(), Default::default());
+                }
+                t => Err(Error::ExpectedObject { actual_type: t })?,
+            }
+        }
+        Stage::Unwind {
+            path,
+            include_array_index,
+            preserve_null_and_empty_arrays,
+        } => {
+            let result_type = infer_type_from_unwind_stage(
+                context,
+                desired_object_type_name,
+                path,
+                include_array_index.as_deref(),
+                *preserve_null_and_empty_arrays,
+            )?;
+            context.set_stage_doc_type(result_type, Default::default())
+        }
+        Stage::Other(doc) => {
+            let warning = Error::UnknownAggregationStage {
+                stage_index,
+                stage: doc.clone(),
+            };
+            context.set_unknown_stage_doc_type(warning);
+        }
+    };
+    match rest.next() {
+        Some((next_stage_index, next_stage)) => infer_result_type_helper(
+            context,
+            desired_object_type_name,
+            next_stage_index,
+            next_stage,
+            rest,
+        ),
+        None => Ok(()),
+    }
+}
+
+pub fn infer_type_from_documents(
+    object_type_name: &ObjectTypeName,
+    documents: &[Document],
+) -> ObjectTypes {
+    let mut collected_object_types = vec![];
+    for document in documents {
+        let object_types = make_object_type(object_type_name, document, false, false);
+        collected_object_types = if collected_object_types.is_empty() {
+            object_types
+        } else {
+            unify_object_types(collected_object_types, object_types)
+        };
+    }
+    collected_object_types
+        .into_iter()
+        .map(|type_with_name| (type_with_name.name, type_with_name.value))
+        .collect()
+}
+
+fn infer_type_from_group_stage(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    key_expression: &Bson,
+    accumulators: &BTreeMap<String, Accumulator>,
+) -> Result<ObjectTypeName> {
+    let group_key_expression_type = infer_type_from_aggregation_expression(
+        context,
+        &format!("{desired_object_type_name}_id"),
+        key_expression.clone(),
+    )?;
+
+    let group_expression_field: (FieldName, ObjectField) = (
+        "_id".into(),
+        ObjectField {
+            r#type: group_key_expression_type.clone(),
+            description: None,
+        },
+    );
+    let accumulator_fields = accumulators.iter().map(|(key, accumulator)| {
+        let accumulator_type = match accumulator {
+            Accumulator::Count => Type::Scalar(BsonScalarType::Int),
+            Accumulator::Min(expr) => infer_type_from_aggregation_expression(
+                context,
+                &format!("{desired_object_type_name}_min"),
+                expr.clone(),
+            )?,
+            Accumulator::Max(expr) => infer_type_from_aggregation_expression(
+                context,
+                &format!("{desired_object_type_name}_max"),
+                expr.clone(),
+            )?,
+            Accumulator::Push(expr) => {
+                let t = infer_type_from_aggregation_expression(
+                    context,
+                    &format!("{desired_object_type_name}_push"),
+                    expr.clone(),
+                )?;
+                Type::ArrayOf(Box::new(t))
+            }
+            Accumulator::Avg(expr) => {
+                let t = infer_type_from_aggregation_expression(
+                    context,
+                    &format!("{desired_object_type_name}_avg"),
+                    expr.clone(),
+                )?;
+                match t {
+                    Type::ExtendedJSON => t,
+                    Type::Scalar(scalar_type) if scalar_type.is_numeric() => t,
+                    _ => Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))),
+                }
+            }
+            Accumulator::Sum(expr) => {
+                let t = infer_type_from_aggregation_expression(
+                    context,
+                    &format!("{desired_object_type_name}_sum"),
+                    expr.clone(),
+                )?;
+                match t {
+                    Type::ExtendedJSON => t,
+                    Type::Scalar(scalar_type) if scalar_type.is_numeric() => t,
+                    _ => Type::Scalar(BsonScalarType::Int),
+                }
+            }
+        };
+        Ok::<_, Error>((
+            key.clone().into(),
+            ObjectField {
+                r#type: accumulator_type,
+                description: None,
+            },
+        ))
+    });
+    let fields = once(Ok(group_expression_field))
+        .chain(accumulator_fields)
+        .collect::<Result<_>>()?;
+
+    let object_type = ObjectType {
+        fields,
+        description: None,
+    };
+    let object_type_name = context.unique_type_name(desired_object_type_name);
+    context.insert_object_type(object_type_name.clone(), object_type);
+    Ok(object_type_name)
+}
+
+fn infer_type_from_unwind_stage(
+    context: &mut PipelineTypeContext<'_>,
+    desired_object_type_name: &str,
+    path: &str,
+    include_array_index: Option<&str>,
+    _preserve_null_and_empty_arrays: Option<bool>,
+) -> Result<ObjectTypeName> {
+    let field_to_unwind = parse_reference_shorthand(path)?;
+    let Reference::InputDocumentField { name, nested_path } = field_to_unwind else {
+        return Err(Error::ExpectedStringPath(path.into()));
+    };
+
+    let field_type = infer_type_from_reference_shorthand(context, path)?;
+    let Type::ArrayOf(field_element_type) = field_type else {
+        return Err(Error::ExpectedArrayReference {
+            reference: path.into(),
+            referenced_type: field_type,
+        });
+    };
+
+    let nested_path_iter = nested_path.into_iter();
+
+    let mut doc_type = context.get_input_document_type()?.into_owned();
+    if let Some(index_field_name) = include_array_index {
+        doc_type.fields.insert(
+            index_field_name.into(),
+            ObjectField {
+                r#type: Type::Scalar(BsonScalarType::Long),
+                description: Some(format!("index of unwound array elements in {name}")),
+            },
+        );
+    }
+
+    // If `path` includes a nested_path then the type for the unwound field will be nested
+    // objects
+    fn build_nested_types(
+        context: &mut PipelineTypeContext<'_>,
+        ultimate_field_type: Type,
+        parent_object_type: &mut ObjectType,
+        desired_object_type_name: &str,
+        field_name: FieldName,
+        mut rest: impl Iterator<Item = FieldName>,
+    ) {
+        match rest.next() {
+            Some(next_field_name) => {
+                let object_type_name = context.unique_type_name(desired_object_type_name);
+                let mut object_type = ObjectType {
+                    fields: Default::default(),
+                    description: None,
+                };
+                build_nested_types(
+                    context,
+                    ultimate_field_type,
+                    &mut object_type,
+                    &format!("{desired_object_type_name}_{next_field_name}"),
+                    next_field_name,
+                    rest,
+                );
+                context.insert_object_type(object_type_name.clone(), object_type);
+                parent_object_type.fields.insert(
+                    field_name,
+                    ObjectField {
+                        r#type: Type::Object(object_type_name.into()),
+                        description: None,
+                    },
+                );
+            }
+            None => {
+                parent_object_type.fields.insert(
+                    field_name,
+                    ObjectField {
+                        r#type: ultimate_field_type,
+                        description: None,
+                    },
+                );
+            }
+        }
+    }
+    build_nested_types(
+        context,
+        *field_element_type,
+        &mut doc_type,
+        desired_object_type_name,
+        name,
+        nested_path_iter,
+    );
+
+    let object_type_name = context.unique_type_name(desired_object_type_name);
+    context.insert_object_type(object_type_name.clone(), doc_type);
+
+    Ok(object_type_name)
+}
+
+#[cfg(test)]
+mod tests {
+    use configuration::schema::{ObjectField, ObjectType, Type};
+    use mongodb::bson::doc;
+    use mongodb_support::{
+        aggregate::{Pipeline, Selection, Stage},
+        BsonScalarType,
+    };
+    use pretty_assertions::assert_eq;
+    use test_helpers::configuration::mflix_config;
+
+    use crate::native_query::pipeline_type_context::PipelineTypeContext;
+
+    use super::{infer_result_type, infer_type_from_unwind_stage};
+
+    type Result<T> = anyhow::Result<T>;
+
+    #[test]
+    fn infers_type_from_documents_stage() -> Result<()> {
+        let pipeline = Pipeline::new(vec![Stage::Documents(vec![
+            doc! { "foo": 1 },
+            doc! { "bar": 2 },
+        ])]);
+        let config = mflix_config();
+        let pipeline_types = infer_result_type(&config, "documents", None, &pipeline).unwrap();
+        let expected = [(
+            "documents_documents".into(),
+            ObjectType {
+                fields: [
+                    (
+                        "foo".into(),
+                        ObjectField {
+                            r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))),
+                            description: None,
+                        },
+                    ),
+                    (
+                        "bar".into(),
+                        ObjectField {
+                            r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))),
+                            description: None,
+                        },
+                    ),
+                ]
+                .into(),
+                description: None,
+            },
+        )]
+        .into();
+        let actual = pipeline_types.object_types;
+        assert_eq!(actual, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn infers_type_from_replace_with_stage() -> Result<()> {
+        let pipeline = Pipeline::new(vec![Stage::ReplaceWith(Selection::new(doc! {
+            "selected_title": "$title"
+        }))]);
+        let config = mflix_config();
+        let pipeline_types = infer_result_type(
+            &config,
+            "movies_selection",
+            Some(&("movies".into())),
+            &pipeline,
+        )
+        .unwrap();
+        let expected = [(
+            "movies_selection".into(),
+            ObjectType {
+                fields: [(
+                    "selected_title".into(),
+                    ObjectField {
+                        r#type: Type::Scalar(BsonScalarType::String),
+                        description: None,
+                    },
+                )]
+                .into(),
+                description: None,
+            },
+        )]
+        .into();
+        let actual = pipeline_types.object_types;
+        assert_eq!(actual, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn infers_type_from_unwind_stage() -> Result<()> {
+        let config = mflix_config();
+        let mut context = PipelineTypeContext::new(&config, None);
+        context.insert_object_type(
+            "words_doc".into(),
+            ObjectType {
+                fields: [(
+                    "words".into(),
+                    ObjectField {
+                        r#type: Type::ArrayOf(Box::new(Type::Scalar(BsonScalarType::String))),
+                        description: None,
+                    },
+                )]
+                .into(),
+                description: None,
+            },
+        );
+        context.set_stage_doc_type("words_doc".into(), Default::default());
+
+        let inferred_type_name = infer_type_from_unwind_stage(
+            &mut context,
+            "unwind_stage",
+            "$words",
+            Some("idx"),
+            Some(false),
+        )?;
+
+        assert_eq!(
+            context
+                .get_object_type(&inferred_type_name)
+                .unwrap()
+                .into_owned(),
+            ObjectType {
+                fields: [
+                    (
+                        "words".into(),
+                        ObjectField {
+                            r#type: Type::Scalar(BsonScalarType::String),
+                            description: None,
+                        }
+                    ),
+                    (
+                        "idx".into(),
+                        ObjectField {
+                            r#type: Type::Scalar(BsonScalarType::Long),
+                            description: Some("index of unwound array elements in words".into()),
+                        }
+                    ),
+                ]
+                .into(),
+                description: None,
+            }
+        );
+        Ok(())
+    }
+}
diff --git a/crates/cli/src/native_query/mod.rs b/crates/cli/src/native_query/mod.rs
new file mode 100644
index 00000000..f25be213
--- /dev/null
+++ b/crates/cli/src/native_query/mod.rs
@@ -0,0 +1,290 @@
+mod aggregation_expression;
+pub mod error;
+mod helpers;
+mod infer_result_type;
+mod pipeline_type_context;
+mod reference_shorthand;
+
+use std::path::{Path, PathBuf};
+use std::process::exit;
+
+use clap::Subcommand;
+use configuration::{
+    native_query::NativeQueryRepresentation::Collection, serialized::NativeQuery, Configuration,
+};
+use configuration::{read_directory, WithName};
+use mongodb_support::aggregate::Pipeline;
+use ndc_models::CollectionName;
+use tokio::fs;
+
+use crate::exit_codes::ExitCode;
+use crate::Context;
+
+use self::error::Result;
+use self::infer_result_type::infer_result_type;
+
+/// Create native queries - custom MongoDB queries that integrate into your data graph
+#[derive(Clone, Debug, Subcommand)]
+pub enum Command {
+    /// Create a native query from a JSON file containing an aggregation pipeline
+    Create {
+        /// Name that will identify the query in your data graph
+        #[arg(long, short = 'n', required = true)]
+        name: String,
+
+        /// Name of the collection that acts as input for the pipeline - omit for a pipeline that does not require input
+        #[arg(long, short = 'c')]
+        collection: Option<CollectionName>,
+
+        /// Overwrite any existing native query configuration with the same name
+        #[arg(long, short = 'f')]
+        force: bool,
+
+        /// Path to a JSON file with an aggregation pipeline
+        pipeline_path: PathBuf,
+    },
+}
+
+pub async fn run(context: &Context, command: Command) -> anyhow::Result<()> {
+    match command {
+        Command::Create {
+            name,
+            collection,
+            force,
+            pipeline_path,
+        } => {
+            let configuration = match read_directory(&context.path).await {
+                Ok(c) => c,
+                Err(err) => {
+                    eprintln!("Could not read connector configuration - configuration must be initialized before creating native queries.\n\n{err}");
+                    exit(ExitCode::CouldNotReadConfiguration.into())
+                }
+            };
+            eprintln!(
+                "Read configuration from {}",
+                &context.path.to_string_lossy()
+            );
+
+            let pipeline = match read_pipeline(&pipeline_path).await {
+                Ok(p) => p,
+                Err(err) => {
+                    eprintln!("Could not read aggregation pipeline.\n\n{err}");
+                    exit(ExitCode::CouldNotReadAggregationPipeline.into())
+                }
+            };
+            let native_query_path = {
+                let path = get_native_query_path(context, &name);
+                if !force && fs::try_exists(&path).await? {
+                    eprintln!(
+                        "A native query named {name} already exists at {}.",
+                        path.to_string_lossy()
+                    );
+                    eprintln!("Re-run with --force to overwrite.");
+                    exit(ExitCode::RefusedToOverwrite.into())
+                }
+                path
+            };
+            let native_query =
+                match native_query_from_pipeline(&configuration, &name, collection, pipeline) {
+                    Ok(q) => WithName::named(name, q),
+                    Err(_) => todo!(),
+                };
+
+            let native_query_dir = native_query_path
+                .parent()
+                .expect("parent directory of native query configuration path");
+            if !(fs::try_exists(&native_query_dir).await?) {
+                fs::create_dir(&native_query_dir).await?;
+            }
+
+            if let Err(err) = fs::write(
+                &native_query_path,
+                serde_json::to_string_pretty(&native_query)?,
+            )
+            .await
+            {
+                eprintln!("Error writing native query configuration: {err}");
+                exit(ExitCode::ErrorWriting.into())
+            };
+            eprintln!(
+                "Wrote native query configuration to {}",
+                native_query_path.to_string_lossy()
+            );
+            Ok(())
+        }
+    }
+}
+
+async fn read_pipeline(pipeline_path: &Path) -> anyhow::Result<Pipeline> {
+    let input = fs::read(pipeline_path).await?;
+    let pipeline = serde_json::from_slice(&input)?;
+    Ok(pipeline)
+}
+
+fn get_native_query_path(context: &Context, name: &str) -> PathBuf {
+    context
+        .path
+        .join(configuration::NATIVE_QUERIES_DIRNAME)
+        .join(name)
+        .with_extension("json")
+}
+
+pub fn native_query_from_pipeline(
+    configuration: &Configuration,
+    name: &str,
+    input_collection: Option<CollectionName>,
+    pipeline: Pipeline,
+) -> Result<NativeQuery> {
+    let pipeline_types =
+        infer_result_type(configuration, name, input_collection.as_ref(), &pipeline)?;
+    // TODO: move warnings to `run` function
+    for warning in pipeline_types.warnings {
+        println!("warning: {warning}");
+    }
+    Ok(NativeQuery {
+        representation: Collection,
+        input_collection,
+        arguments: Default::default(), // TODO: infer arguments
+        result_document_type: pipeline_types.result_document_type,
+        object_types: pipeline_types.object_types,
+        pipeline: pipeline.into(),
+        description: None,
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use anyhow::Result;
+    use configuration::{
+        native_query::NativeQueryRepresentation::Collection,
+        read_directory,
+        schema::{ObjectField, ObjectType, Type},
+        serialized::NativeQuery,
+        Configuration,
+    };
+    use mongodb::bson::doc;
+    use mongodb_support::{
+        aggregate::{Accumulator, Pipeline, Selection, Stage},
+        BsonScalarType,
+    };
+    use ndc_models::ObjectTypeName;
+    use pretty_assertions::assert_eq;
+
+    use super::native_query_from_pipeline;
+
+    #[tokio::test]
+    async fn infers_native_query_from_pipeline() -> Result<()> {
+        let config = read_configuration().await?;
+        let pipeline = Pipeline::new(vec![Stage::Documents(vec![
+            doc! { "foo": 1 },
+            doc! { "bar": 2 },
+        ])]);
+        let native_query = native_query_from_pipeline(
+            &config,
+            "selected_title",
+            Some("movies".into()),
+            pipeline.clone(),
+        )?;
+
+        let expected_document_type_name: ObjectTypeName = "selected_title_documents".into();
+
+        let expected_object_types = [(
+            expected_document_type_name.clone(),
+            ObjectType {
+                fields: [
+                    (
+                        "foo".into(),
+                        ObjectField {
+                            r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))),
+                            description: None,
+                        },
+                    ),
+                    (
+                        "bar".into(),
+                        ObjectField {
+                            r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))),
+                            description: None,
+                        },
+                    ),
+                ]
+                .into(),
+                description: None,
+            },
+        )]
+        .into();
+
+        let expected = NativeQuery {
+            representation: Collection,
+            input_collection: Some("movies".into()),
+            arguments: Default::default(),
+            result_document_type: expected_document_type_name,
+            object_types: expected_object_types,
+            pipeline: pipeline.into(),
+            description: None,
+        };
+
+        assert_eq!(native_query, expected);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn infers_native_query_from_non_trivial_pipeline() -> Result<()> {
+        let config = read_configuration().await?;
+        let pipeline = Pipeline::new(vec![
+            Stage::ReplaceWith(Selection::new(doc! {
+                "title_words": { "$split": ["$title", " "] }
+            })),
+            Stage::Unwind {
+                path: "$title_words".to_string(),
+                include_array_index: None,
+                preserve_null_and_empty_arrays: None,
+            },
+            Stage::Group {
+                key_expression: "$title_words".into(),
+                accumulators: [("title_count".into(), Accumulator::Count)].into(),
+            },
+        ]);
+        let native_query = native_query_from_pipeline(
+            &config,
+            "title_word_frequency",
+            Some("movies".into()),
+            pipeline.clone(),
+        )?;
+
+        assert_eq!(native_query.input_collection, Some("movies".into()));
+        assert!(native_query
+            .result_document_type
+            .to_string()
+            .starts_with("title_word_frequency"));
+        assert_eq!(
+            native_query
+                .object_types
+                .get(&native_query.result_document_type),
+            Some(&ObjectType {
+                fields: [
+                    (
+                        "_id".into(),
+                        ObjectField {
+                            r#type: Type::Scalar(BsonScalarType::String),
+                            description: None,
+                        },
+                    ),
+                    (
+                        "title_count".into(),
+                        ObjectField {
+                            r#type: Type::Scalar(BsonScalarType::Int),
+                            description: None,
+                        },
+                    ),
+                ]
+                .into(),
+                description: None,
+            })
+        );
+        Ok(())
+    }
+
+    async fn read_configuration() -> Result<Configuration> {
+        read_directory("../../fixtures/hasura/sample_mflix/connector").await
+    }
+}
diff --git a/crates/cli/src/native_query/pipeline_type_context.rs b/crates/cli/src/native_query/pipeline_type_context.rs
new file mode 100644
index 00000000..8c64839c
--- /dev/null
+++ b/crates/cli/src/native_query/pipeline_type_context.rs
@@ -0,0 +1,175 @@
+#![allow(dead_code)]
+
+use std::{
+    borrow::Cow,
+    collections::{BTreeMap, HashMap, HashSet},
+};
+
+use configuration::{
+    schema::{ObjectType, Type},
+    Configuration,
+};
+use deriving_via::DerivingVia;
+use ndc_models::ObjectTypeName;
+
+use super::error::{Error, Result};
+
+type ObjectTypes = BTreeMap<ObjectTypeName, ObjectType>;
+
+#[derive(DerivingVia)]
+#[deriving(Copy, Debug, Eq, Hash)]
+pub struct TypeVariable(u32);
+
+/// Information exported from [PipelineTypeContext] after type inference is complete.
+#[derive(Clone, Debug)]
+pub struct PipelineTypes {
+    pub result_document_type: ObjectTypeName,
+    pub object_types: BTreeMap<ObjectTypeName, ObjectType>,
+    pub warnings: Vec<Error>,
+}
+
+impl<'a> TryFrom<PipelineTypeContext<'a>> for PipelineTypes {
+    type Error = Error;
+
+    fn try_from(context: PipelineTypeContext<'a>) -> Result<Self> {
+        Ok(Self {
+            result_document_type: context.get_input_document_type_name()?.into(),
+            object_types: context.object_types.clone(),
+            warnings: context.warnings,
+        })
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct PipelineTypeContext<'a> {
+    configuration: &'a Configuration,
+
+    /// Document type for inputs to the pipeline stage being evaluated. At the start of the
+    /// pipeline this is the document type for the input collection, if there is one.
+    input_doc_type: Option<HashSet<Constraint>>,
+
+    /// Object types defined in the process of type inference. [self.input_doc_type] may refer
+    /// to a type here, or in [self.configuration.object_types]
+    object_types: ObjectTypes,
+
+    type_variables: HashMap<TypeVariable, HashSet<Constraint>>,
+    next_type_variable: u32,
+
+    warnings: Vec<Error>,
+}
+
+impl PipelineTypeContext<'_> {
+    pub fn new(
+        configuration: &Configuration,
+        input_collection_document_type: Option<ObjectTypeName>,
+    ) -> PipelineTypeContext<'_> {
+        PipelineTypeContext {
+            configuration,
+            input_doc_type: input_collection_document_type.map(|type_name| {
+                HashSet::from_iter([Constraint::ConcreteType(Type::Object(
+                    type_name.to_string(),
+                ))])
+            }),
+            object_types: Default::default(),
+            type_variables: Default::default(),
+            next_type_variable: 0,
+            warnings: Default::default(),
+        }
+    }
+
+    pub fn new_type_variable(
+        &mut self,
+        constraints: impl IntoIterator<Item = Constraint>,
+    ) -> TypeVariable {
+        let variable = TypeVariable(self.next_type_variable);
+        self.next_type_variable += 1;
+        self.type_variables
+            .insert(variable, constraints.into_iter().collect());
+        variable
+    }
+
+    pub fn set_type_variable_constraint(&mut self, variable: TypeVariable, constraint: Constraint) {
+        let entry = self
+            .type_variables
+            .get_mut(&variable)
+            .expect("unknown type variable");
+        entry.insert(constraint);
+    }
+
+    pub fn insert_object_type(&mut self, name: ObjectTypeName, object_type: ObjectType) {
+        self.object_types.insert(name, object_type);
+    }
+
+    pub fn unique_type_name(&self, desired_type_name: &str) -> ObjectTypeName {
+        let mut counter = 0;
+        let mut type_name: ObjectTypeName = desired_type_name.into();
+        while self.configuration.object_types.contains_key(&type_name)
+            || self.object_types.contains_key(&type_name)
+        {
+            counter += 1;
+            type_name = format!("{desired_type_name}_{counter}").into();
+        }
+        type_name
+    }
+
+    pub fn set_stage_doc_type(&mut self, type_name: ObjectTypeName, mut object_types: ObjectTypes) {
+        self.input_doc_type = Some(
+            [Constraint::ConcreteType(Type::Object(
+                type_name.to_string(),
+            ))]
+            .into(),
+        );
+        self.object_types.append(&mut object_types);
+    }
+
+    pub fn set_unknown_stage_doc_type(&mut self, warning: Error) {
+        self.input_doc_type = Some([].into());
+        self.warnings.push(warning);
+    }
+
+    pub fn get_object_type(&self, name: &ObjectTypeName) -> Option<Cow<ObjectType>> {
+        if let Some(object_type) = self.configuration.object_types.get(name) {
+            let schema_object_type = object_type.clone().into();
+            return Some(Cow::Owned(schema_object_type));
+        }
+        if let Some(object_type) = self.object_types.get(name) {
+            return Some(Cow::Borrowed(object_type));
+        }
+        None
+    }
+
+    /// Get the input document type for the next stage. Forces to a concrete type, and returns an
+    /// error if a concrete type cannot be inferred.
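+    /// For example, if the constraint set holds exactly one
+    /// `Constraint::ConcreteType(Type::Object(name))` entry, the result is that type name.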
+    pub fn get_input_document_type_name(&self) -> Result<&str> {
+        match &self.input_doc_type {
+            None => Err(Error::IncompletePipeline),
+            Some(constraints) => {
+                let len = constraints.len();
+                let first_constraint = constraints.iter().next();
+                if let (1, Some(Constraint::ConcreteType(Type::Object(t)))) =
+                    (len, first_constraint)
+                {
+                    Ok(t)
+                } else {
+                    Err(Error::UnableToInferResultType)
+                }
+            }
+        }
+    }
+
+    pub fn get_input_document_type(&self) -> Result<Cow<ObjectType>> {
+        let document_type_name = self.get_input_document_type_name()?.into();
+        Ok(self
+            .get_object_type(&document_type_name)
+            .expect("if we have an input document type name we should have the object type"))
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub enum Constraint {
+    /// The variable appears in a context with a specific type, and this is it.
+    ConcreteType(Type),
+
+    /// The variable has the same type as another type variable.
+    TypeRef(TypeVariable),
+}
diff --git a/crates/cli/src/native_query/reference_shorthand.rs b/crates/cli/src/native_query/reference_shorthand.rs
new file mode 100644
index 00000000..8202567d
--- /dev/null
+++ b/crates/cli/src/native_query/reference_shorthand.rs
@@ -0,0 +1,130 @@
+use ndc_models::FieldName;
+use nom::{
+    branch::alt,
+    bytes::complete::{tag, take_while1},
+    character::complete::{alpha1, alphanumeric1},
+    combinator::{all_consuming, cut, map, opt, recognize},
+    multi::{many0, many0_count},
+    sequence::{delimited, pair, preceded},
+    IResult,
+};
+
+use super::error::{Error, Result};
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum Reference {
+    /// Reference to a variable that is substituted by the connector from GraphQL inputs before
+    /// sending to MongoDB. For example, `"{{ artist_id }}"`.
+    NativeQueryVariable {
+        name: String,
+        type_annotation: Option<String>,
+    },
+
+    /// Reference to a variable that is defined as part of the pipeline syntax. May be followed by
+    /// a dot-separated path to a nested field. For example, `"$$CURRENT.foo.bar"`
+    PipelineVariable {
+        name: String,
+        nested_path: Vec<FieldName>,
+    },
+
+    /// Reference to a field of the input document. May be followed by a dot-separated path to
+    /// a nested field. For example, `"$tomatoes.viewer.rating"`
+    InputDocumentField {
+        name: FieldName,
+        nested_path: Vec<FieldName>,
+    },
+
+    /// The expression evaluates to a string - that's all we need to know
+    String,
+}
+
+pub fn parse_reference_shorthand(input: &str) -> Result<Reference> {
+    match reference_shorthand(input) {
+        Ok((_, r)) => Ok(r),
+        Err(err) => Err(Error::UnableToParseReferenceShorthand(format!("{err}"))),
+    }
+}
+
+/// Reference shorthand is a string in an aggregation expression that may evaluate to the value of
+/// a field of the input document if the string begins with $, or to a variable if it begins with
+/// $$, or may be a plain string.
+fn reference_shorthand(input: &str) -> IResult<&str, Reference> {
+    all_consuming(alt((
+        native_query_variable,
+        pipeline_variable,
+        input_document_field,
+        plain_string,
+    )))(input)
+}
+
+// A native query variable placeholder might be embedded in a larger string. But in that case the
+// expression evaluates to a string so we ignore it.
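+// For example, `"{{ artist_id | string }}"` on its own parses to `NativeQueryVariable` with the
+// name `artist_id` and the type annotation `string` - whitespace inside the braces is trimmed.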
+fn native_query_variable(input: &str) -> IResult<&str, Reference> { + let placeholder_content = |input| { + map(take_while1(|c| c != '}' && c != '|'), |content: &str| { + content.trim() + })(input) + }; + let type_annotation = preceded(tag("|"), placeholder_content); + + let (remaining, (name, variable_type)) = delimited( + tag("{{"), + cut(pair(placeholder_content, opt(type_annotation))), + tag("}}"), + )(input)?; + // Since the native_query_variable parser runs inside an `alt`, the use of `cut` commits to + // this branch of the `alt` after successfully parsing the opening "{{" characters. + + let variable = Reference::NativeQueryVariable { + name: name.to_string(), + type_annotation: variable_type.map(ToString::to_string), + }; + Ok((remaining, variable)) +} + +fn pipeline_variable(input: &str) -> IResult<&str, Reference> { + let variable_parser = preceded(tag("$$"), cut(mongodb_variable_name)); + let (remaining, (name, path)) = pair(variable_parser, nested_path)(input)?; + let variable = Reference::PipelineVariable { + name: name.to_string(), + nested_path: path, + }; + Ok((remaining, variable)) +} + +fn input_document_field(input: &str) -> IResult<&str, Reference> { + let field_parser = preceded(tag("$"), cut(mongodb_variable_name)); + let (remaining, (name, path)) = pair(field_parser, nested_path)(input)?; + let field = Reference::InputDocumentField { + name: name.into(), + nested_path: path, + }; + Ok((remaining, field)) +} + +fn mongodb_variable_name(input: &str) -> IResult<&str, &str> { + let first_char = alt((alpha1, tag("_"))); + let succeeding_char = alt((alphanumeric1, tag("_"), non_ascii1)); + recognize(pair(first_char, many0_count(succeeding_char)))(input) +} + +fn nested_path(input: &str) -> IResult<&str, Vec> { + let component_parser = preceded(tag("."), take_while1(|c| c != '.')); + let (remaining, components) = many0(component_parser)(input)?; + Ok(( + remaining, + components.into_iter().map(|c| c.into()).collect(), + )) +} + +fn non_ascii1(input: &str) -> IResult<&str, &str> { + take_while1(is_non_ascii)(input) +} + +fn is_non_ascii(char: char) -> bool { + char as u8 > 127 +} + +fn plain_string(_input: &str) -> IResult<&str, Reference> { + Ok(("", Reference::String)) +} diff --git a/crates/configuration/Cargo.toml b/crates/configuration/Cargo.toml index dd67b71e..264c51d5 100644 --- a/crates/configuration/Cargo.toml +++ b/crates/configuration/Cargo.toml @@ -12,6 +12,7 @@ futures = "^0.3" itertools = { workspace = true } mongodb = { workspace = true } ndc-models = { workspace = true } +ref-cast = { workspace = true } schemars = { workspace = true } serde = { version = "1", features = ["derive"] } serde_json = { version = "1" } diff --git a/crates/configuration/src/lib.rs b/crates/configuration/src/lib.rs index c9c2f971..822aa1fe 100644 --- a/crates/configuration/src/lib.rs +++ b/crates/configuration/src/lib.rs @@ -13,6 +13,10 @@ pub use crate::directory::list_existing_schemas; pub use crate::directory::parse_configuration_options_file; pub use crate::directory::read_directory; pub use crate::directory::write_schema_directory; +pub use crate::directory::{ + CONFIGURATION_OPTIONS_BASENAME, CONFIGURATION_OPTIONS_METADATA, NATIVE_MUTATIONS_DIRNAME, + NATIVE_QUERIES_DIRNAME, SCHEMA_DIRNAME, +}; pub use crate::mongo_scalar_type::MongoScalarType; pub use crate::serialized::Schema; pub use crate::with_name::{WithName, WithNameRef}; diff --git a/crates/configuration/src/native_query.rs b/crates/configuration/src/native_query.rs index e8986bb6..2cf875f4 100644 --- 
a/crates/configuration/src/native_query.rs +++ b/crates/configuration/src/native_query.rs @@ -5,7 +5,7 @@ use ndc_models as ndc; use ndc_query_plan as plan; use plan::QueryPlanError; use schemars::JsonSchema; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::serialized; @@ -39,7 +39,7 @@ impl NativeQuery { } } -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Hash, JsonSchema)] +#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "camelCase")] pub enum NativeQueryRepresentation { Collection, diff --git a/crates/configuration/src/schema/mod.rs b/crates/configuration/src/schema/mod.rs index 3476e75f..55a9214c 100644 --- a/crates/configuration/src/schema/mod.rs +++ b/crates/configuration/src/schema/mod.rs @@ -1,11 +1,12 @@ use std::collections::BTreeMap; +use ref_cast::RefCast as _; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use mongodb_support::BsonScalarType; -use crate::{WithName, WithNameRef}; +use crate::{MongoScalarType, WithName, WithNameRef}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] @@ -18,7 +19,7 @@ pub struct Collection { } /// The type of values that a column, field, or argument may take. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub enum Type { /// Any BSON value, represented as Extended JSON. @@ -100,6 +101,30 @@ impl From for ndc_models::Type { } } +impl From for Type { + fn from(t: ndc_models::Type) -> Self { + match t { + ndc_models::Type::Named { name } => { + let scalar_type_name = ndc_models::ScalarTypeName::ref_cast(&name); + match MongoScalarType::try_from(scalar_type_name) { + Ok(MongoScalarType::Bson(scalar_type)) => Type::Scalar(scalar_type), + Ok(MongoScalarType::ExtendedJSON) => Type::ExtendedJSON, + Err(_) => Type::Object(name.to_string()), + } + } + ndc_models::Type::Nullable { underlying_type } => { + Type::Nullable(Box::new(Self::from(*underlying_type))) + } + ndc_models::Type::Array { element_type } => { + Type::ArrayOf(Box::new(Self::from(*element_type))) + } + ndc_models::Type::Predicate { object_type_name } => { + Type::Predicate { object_type_name } + } + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct ObjectType { @@ -139,6 +164,19 @@ impl From for ndc_models::ObjectType { } } +impl From for ObjectType { + fn from(object_type: ndc_models::ObjectType) -> Self { + ObjectType { + description: object_type.description, + fields: object_type + .fields + .into_iter() + .map(|(name, field)| (name, field.into())) + .collect(), + } + } +} + /// Information about an object type field. 
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] @@ -169,3 +207,12 @@ impl From for ndc_models::ObjectField { } } } + +impl From for ObjectField { + fn from(field: ndc_models::ObjectField) -> Self { + ObjectField { + description: field.description, + r#type: field.r#type.into(), + } + } +} diff --git a/crates/configuration/src/serialized/native_query.rs b/crates/configuration/src/serialized/native_query.rs index 11ff4b87..9fde303f 100644 --- a/crates/configuration/src/serialized/native_query.rs +++ b/crates/configuration/src/serialized/native_query.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use mongodb::bson; use schemars::JsonSchema; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::{ native_query::NativeQueryRepresentation, @@ -11,7 +11,7 @@ use crate::{ /// Define an arbitrary MongoDB aggregation pipeline that can be referenced in your data graph. For /// details on aggregation pipelines see https://www.mongodb.com/docs/manual/core/aggregation-pipeline/ -#[derive(Clone, Debug, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Deserialize, Serialize, JsonSchema, PartialEq)] #[serde(rename_all = "camelCase")] pub struct NativeQuery { /// Representation may be either "collection" or "function". If you choose "collection" then diff --git a/crates/mongodb-agent-common/src/mongodb/collection.rs b/crates/mongodb-agent-common/src/mongodb/collection.rs index 090dc66a..db759d1d 100644 --- a/crates/mongodb-agent-common/src/mongodb/collection.rs +++ b/crates/mongodb-agent-common/src/mongodb/collection.rs @@ -6,13 +6,12 @@ use mongodb::{ options::{AggregateOptions, FindOptions}, Collection, }; +use mongodb_support::aggregate::Pipeline; use serde::de::DeserializeOwned; #[cfg(test)] use mockall::automock; -use super::Pipeline; - #[cfg(test)] use super::test_helpers::MockCursor; diff --git a/crates/mongodb-agent-common/src/mongodb/database.rs b/crates/mongodb-agent-common/src/mongodb/database.rs index ce56a06f..16be274b 100644 --- a/crates/mongodb-agent-common/src/mongodb/database.rs +++ b/crates/mongodb-agent-common/src/mongodb/database.rs @@ -1,11 +1,12 @@ use async_trait::async_trait; use futures_util::Stream; use mongodb::{bson::Document, error::Error, options::AggregateOptions, Database}; +use mongodb_support::aggregate::Pipeline; #[cfg(test)] use mockall::automock; -use super::{CollectionTrait, Pipeline}; +use super::CollectionTrait; #[cfg(test)] use super::MockCollectionTrait; diff --git a/crates/mongodb-agent-common/src/mongodb/mod.rs b/crates/mongodb-agent-common/src/mongodb/mod.rs index d1a7c8c4..361dbf89 100644 --- a/crates/mongodb-agent-common/src/mongodb/mod.rs +++ b/crates/mongodb-agent-common/src/mongodb/mod.rs @@ -1,18 +1,13 @@ -mod accumulator; mod collection; mod database; -mod pipeline; pub mod sanitize; mod selection; -mod sort_document; -mod stage; #[cfg(test)] pub mod test_helpers; pub use self::{ - accumulator::Accumulator, collection::CollectionTrait, database::DatabaseTrait, - pipeline::Pipeline, selection::Selection, sort_document::SortDocument, stage::Stage, + collection::CollectionTrait, database::DatabaseTrait, selection::selection_from_query_request, }; // MockCollectionTrait is generated by automock when the test flag is active. 
diff --git a/crates/mongodb-agent-common/src/mongodb/selection.rs b/crates/mongodb-agent-common/src/mongodb/selection.rs index 0307533e..84c166bf 100644 --- a/crates/mongodb-agent-common/src/mongodb/selection.rs +++ b/crates/mongodb-agent-common/src/mongodb/selection.rs @@ -1,7 +1,7 @@ use indexmap::IndexMap; -use mongodb::bson::{self, doc, Bson, Document}; +use mongodb::bson::{doc, Bson, Document}; +use mongodb_support::aggregate::Selection; use ndc_models::FieldName; -use serde::{Deserialize, Serialize}; use crate::{ interface_types::MongoAgentError, @@ -10,33 +10,18 @@ use crate::{ query::column_ref::ColumnRef, }; -/// Wraps a BSON document that represents a MongoDB "expression" that constructs a document based -/// on the output of a previous aggregation pipeline stage. A Selection value is intended to be -/// used as the argument to a $replaceWith pipeline stage. -/// -/// When we compose pipelines, we can pair each Pipeline with a Selection that extracts the data we -/// want, in the format we want it to provide to HGE. We can collect Selection values and merge -/// them to form one stage after all of the composed pipelines. -#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] -#[serde(transparent)] -pub struct Selection(pub bson::Document); - -impl Selection { - pub fn from_doc(doc: bson::Document) -> Self { - Selection(doc) - } - - pub fn from_query_request(query_request: &QueryPlan) -> Result { - // let fields = (&query_request.query.fields).flatten().unwrap_or_default(); - let empty_map = IndexMap::new(); - let fields = if let Some(fs) = &query_request.query.fields { - fs - } else { - &empty_map - }; - let doc = from_query_request_helper(None, fields)?; - Ok(Selection(doc)) - } +pub fn selection_from_query_request( + query_request: &QueryPlan, +) -> Result { + // let fields = (&query_request.query.fields).flatten().unwrap_or_default(); + let empty_map = IndexMap::new(); + let fields = if let Some(fs) = &query_request.query.fields { + fs + } else { + &empty_map + }; + let doc = from_query_request_helper(None, fields)?; + Ok(Selection::new(doc)) } fn from_query_request_helper( @@ -188,27 +173,6 @@ fn nested_column_reference<'a>( } } -/// The extend implementation provides a shallow merge. -impl Extend<(String, Bson)> for Selection { - fn extend>(&mut self, iter: T) { - self.0.extend(iter); - } -} - -impl From for bson::Document { - fn from(value: Selection) -> Self { - value.0 - } -} - -// This won't fail, but it might in the future if we add some sort of validation or parsing. -impl TryFrom for Selection { - type Error = anyhow::Error; - fn try_from(value: bson::Document) -> Result { - Ok(Selection(value)) - } -} - #[cfg(test)] mod tests { use configuration::Configuration; @@ -220,9 +184,7 @@ mod tests { }; use pretty_assertions::assert_eq; - use crate::mongo_query_plan::MongoConfiguration; - - use super::Selection; + use crate::{mongo_query_plan::MongoConfiguration, mongodb::selection_from_query_request}; #[test] fn calculates_selection_for_query_request() -> Result<(), anyhow::Error> { @@ -250,7 +212,7 @@ mod tests { let query_plan = plan_for_query_request(&foo_config(), query_request)?; - let selection = Selection::from_query_request(&query_plan)?; + let selection = selection_from_query_request(&query_plan)?; assert_eq!( Into::::into(selection), doc! { @@ -342,7 +304,7 @@ mod tests { // twice (once with the key `class_students`, and then with the key `class_students_0`). // This is because the queries on the two relationships have different scope names. 
The // query would work with just one lookup. Can we do that optimization? - let selection = Selection::from_query_request(&query_plan)?; + let selection = selection_from_query_request(&query_plan)?; assert_eq!( Into::<bson::Document>::into(selection), doc! { diff --git a/crates/mongodb-agent-common/src/query/execute_query_request.rs b/crates/mongodb-agent-common/src/query/execute_query_request.rs index d1193ebc..aa1b4551 100644 --- a/crates/mongodb-agent-common/src/query/execute_query_request.rs +++ b/crates/mongodb-agent-common/src/query/execute_query_request.rs @@ -1,6 +1,7 @@ use futures::Stream; use futures_util::TryStreamExt as _; use mongodb::bson; +use mongodb_support::aggregate::Pipeline; use ndc_models::{QueryRequest, QueryResponse}; use ndc_query_plan::plan_for_query_request; use tracing::{instrument, Instrument}; @@ -9,7 +10,7 @@ use super::{pipeline::pipeline_for_query_request, response::serialize_query_resp use crate::{ interface_types::MongoAgentError, mongo_query_plan::{MongoConfiguration, QueryPlan}, - mongodb::{CollectionTrait as _, DatabaseTrait, Pipeline}, + mongodb::{CollectionTrait as _, DatabaseTrait}, query::QueryTarget, }; diff --git a/crates/mongodb-agent-common/src/query/foreach.rs b/crates/mongodb-agent-common/src/query/foreach.rs index ce783864..4995eb40 100644 --- a/crates/mongodb-agent-common/src/query/foreach.rs +++ b/crates/mongodb-agent-common/src/query/foreach.rs @@ -1,6 +1,7 @@ use anyhow::anyhow; use itertools::Itertools as _; use mongodb::bson::{self, doc, Bson}; +use mongodb_support::aggregate::{Pipeline, Selection, Stage}; use ndc_query_plan::VariableSet; use super::pipeline::pipeline_for_non_foreach; @@ -8,12 +9,8 @@ use super::query_level::QueryLevel; use super::query_variable_name::query_variable_name; use super::serialization::json_to_bson; use super::QueryTarget; +use crate::interface_types::MongoAgentError; use crate::mongo_query_plan::{MongoConfiguration, QueryPlan, Type, VariableTypes}; -use crate::mongodb::Selection; -use crate::{ - interface_types::MongoAgentError, - mongodb::{Pipeline, Stage}, -}; type Result<T> = std::result::Result<T, MongoAgentError>; @@ -62,7 +59,7 @@ pub fn pipeline_for_foreach( "rows": "$query" } }; - let selection_stage = Stage::ReplaceWith(Selection(selection)); + let selection_stage = Stage::ReplaceWith(Selection::new(selection)); Ok(Pipeline { stages: vec![variable_sets_stage, lookup_stage, selection_stage], diff --git a/crates/mongodb-agent-common/src/query/make_sort.rs b/crates/mongodb-agent-common/src/query/make_sort.rs index e2de1d35..7adad5a8 100644 --- a/crates/mongodb-agent-common/src/query/make_sort.rs +++ b/crates/mongodb-agent-common/src/query/make_sort.rs @@ -2,12 +2,13 @@ use std::{collections::BTreeMap, iter::once}; use itertools::join; use mongodb::bson::bson; +use mongodb_support::aggregate::{SortDocument, Stage}; use ndc_models::OrderDirection; use crate::{ interface_types::MongoAgentError, mongo_query_plan::{OrderBy, OrderByTarget}, - mongodb::{sanitize::escape_invalid_variable_chars, SortDocument, Stage}, + mongodb::sanitize::escape_invalid_variable_chars, }; use super::column_ref::ColumnRef; @@ -112,11 +113,12 @@ fn safe_alias(target: &OrderByTarget) -> Result<String, MongoAgentError> { #[cfg(test)] mod tests { use mongodb::bson::doc; + use mongodb_support::aggregate::SortDocument; use ndc_models::{FieldName, OrderDirection}; use ndc_query_plan::OrderByElement; use pretty_assertions::assert_eq; - use crate::{mongo_query_plan::OrderBy, mongodb::SortDocument, query::column_ref::ColumnRef}; + use crate::{mongo_query_plan::OrderBy,
query::column_ref::ColumnRef}; use super::make_sort; diff --git a/crates/mongodb-agent-common/src/query/native_query.rs b/crates/mongodb-agent-common/src/query/native_query.rs index 946b5eea..b5a7a4c2 100644 --- a/crates/mongodb-agent-common/src/query/native_query.rs +++ b/crates/mongodb-agent-common/src/query/native_query.rs @@ -3,12 +3,12 @@ use std::collections::BTreeMap; use configuration::native_query::NativeQuery; use itertools::Itertools as _; use mongodb::bson::Bson; +use mongodb_support::aggregate::{Pipeline, Stage}; use ndc_models::ArgumentName; use crate::{ interface_types::MongoAgentError, mongo_query_plan::{Argument, MongoConfiguration, QueryPlan}, - mongodb::{Pipeline, Stage}, procedure::{interpolated_command, ProcedureError}, }; diff --git a/crates/mongodb-agent-common/src/query/pipeline.rs b/crates/mongodb-agent-common/src/query/pipeline.rs index 4d72bf26..a831d923 100644 --- a/crates/mongodb-agent-common/src/query/pipeline.rs +++ b/crates/mongodb-agent-common/src/query/pipeline.rs @@ -2,13 +2,14 @@ use std::collections::BTreeMap; use itertools::Itertools; use mongodb::bson::{self, doc, Bson}; +use mongodb_support::aggregate::{Accumulator, Pipeline, Selection, Stage}; use tracing::instrument; use crate::{ aggregation_function::AggregationFunction, interface_types::MongoAgentError, mongo_query_plan::{Aggregate, MongoConfiguration, Query, QueryPlan}, - mongodb::{sanitize::get_field, Accumulator, Pipeline, Selection, Stage}, + mongodb::{sanitize::get_field, selection_from_query_request}, }; use super::{ @@ -116,15 +117,18 @@ pub fn pipeline_for_fields_facet( .. } = &query_plan.query; - let mut selection = Selection::from_query_request(query_plan)?; + let mut selection = selection_from_query_request(query_plan)?; if query_level != QueryLevel::Top { // Queries higher up the chain might need to reference relationships from this query. So we // forward relationship arrays if this is not the top-level query. 
for relationship_key in relationships.keys() { - selection.0.insert( - relationship_key.to_owned(), - get_field(relationship_key.as_str()), - ); + selection = selection.try_map_document(|mut doc| { + doc.insert( + relationship_key.to_owned(), + get_field(relationship_key.as_str()), + ); + doc + })?; } } @@ -209,7 +213,7 @@ fn facet_pipelines_for_query( _ => None, }; - let selection = Selection( + let selection = Selection::new( [select_aggregates, select_rows] .into_iter() .flatten() diff --git a/crates/mongodb-agent-common/src/query/relations.rs b/crates/mongodb-agent-common/src/query/relations.rs index f909627f..7b634ed6 100644 --- a/crates/mongodb-agent-common/src/query/relations.rs +++ b/crates/mongodb-agent-common/src/query/relations.rs @@ -2,16 +2,13 @@ use std::collections::BTreeMap; use itertools::Itertools as _; use mongodb::bson::{doc, Bson, Document}; +use mongodb_support::aggregate::{Pipeline, Stage}; use ndc_query_plan::Scope; use crate::mongo_query_plan::{MongoConfiguration, Query, QueryPlan}; use crate::mongodb::sanitize::safe_name; -use crate::mongodb::Pipeline; use crate::query::column_ref::name_from_scope; -use crate::{ - interface_types::MongoAgentError, - mongodb::{sanitize::variable, Stage}, -}; +use crate::{interface_types::MongoAgentError, mongodb::sanitize::variable}; use super::pipeline::pipeline_for_non_foreach; use super::query_level::QueryLevel; diff --git a/crates/mongodb-agent-common/src/state.rs b/crates/mongodb-agent-common/src/state.rs index 7875c7ab..07fae77d 100644 --- a/crates/mongodb-agent-common/src/state.rs +++ b/crates/mongodb-agent-common/src/state.rs @@ -25,13 +25,18 @@ impl ConnectorState { pub async fn try_init_state() -> Result<ConnectorState, Box<dyn Error + Send + Sync>> { // Splitting this out of the `Connector` impl makes error translation easier let database_uri = env::var(DATABASE_URI_ENV_VAR)?; - try_init_state_from_uri(&database_uri).await + let state = try_init_state_from_uri(Some(&database_uri)).await?; + Ok(state) } pub async fn try_init_state_from_uri( - database_uri: &str, -) -> Result<ConnectorState, Box<dyn Error + Send + Sync>> { - let client = get_mongodb_client(database_uri).await?; + database_uri: Option<&impl AsRef<str>>, +) -> anyhow::Result<ConnectorState> { + let database_uri = database_uri.ok_or(anyhow!( + "Missing environment variable {}", + DATABASE_URI_ENV_VAR + ))?; + let client = get_mongodb_client(database_uri.as_ref()).await?; let database_name = match client.default_database() { Some(database) => Ok(database.name().to_owned()), None => Err(anyhow!( diff --git a/crates/mongodb-agent-common/src/test_helpers.rs b/crates/mongodb-agent-common/src/test_helpers.rs index cc78a049..c8cd2ccd 100644 --- a/crates/mongodb-agent-common/src/test_helpers.rs +++ b/crates/mongodb-agent-common/src/test_helpers.rs @@ -161,36 +161,5 @@ pub fn chinook_relationships() -> BTreeMap<ndc_models::RelationshipName, Relationship> { /// Configuration for a MongoDB database that resembles MongoDB's sample_mflix test data set.
pub fn mflix_config() -> MongoConfiguration { - MongoConfiguration(Configuration { - collections: [collection("comments"), collection("movies")].into(), - object_types: [ - ( - "comments".into(), - object_type([ - ("_id", named_type("ObjectId")), - ("movie_id", named_type("ObjectId")), - ("name", named_type("String")), - ]), - ), - ( - "credits".into(), - object_type([("director", named_type("String"))]), - ), - ( - "movies".into(), - object_type([ - ("_id", named_type("ObjectId")), - ("credits", named_type("credits")), - ("title", named_type("String")), - ("year", named_type("Int")), - ]), - ), - ] - .into(), - functions: Default::default(), - procedures: Default::default(), - native_mutations: Default::default(), - native_queries: Default::default(), - options: Default::default(), - }) + MongoConfiguration(test_helpers::configuration::mflix_config()) } diff --git a/crates/mongodb-support/Cargo.toml b/crates/mongodb-support/Cargo.toml index a3718e2c..95ca3c3b 100644 --- a/crates/mongodb-support/Cargo.toml +++ b/crates/mongodb-support/Cargo.toml @@ -4,6 +4,7 @@ edition = "2021" version.workspace = true [dependencies] +anyhow = "1" enum-iterator = "^2.0.0" indexmap = { workspace = true } mongodb = { workspace = true } @@ -11,6 +12,3 @@ schemars = "^0.8.12" serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "1" - -[dev-dependencies] -anyhow = "1" diff --git a/crates/mongodb-agent-common/src/mongodb/accumulator.rs b/crates/mongodb-support/src/aggregate/accumulator.rs similarity index 100% rename from crates/mongodb-agent-common/src/mongodb/accumulator.rs rename to crates/mongodb-support/src/aggregate/accumulator.rs diff --git a/crates/mongodb-support/src/aggregate/mod.rs b/crates/mongodb-support/src/aggregate/mod.rs new file mode 100644 index 00000000..dfab9856 --- /dev/null +++ b/crates/mongodb-support/src/aggregate/mod.rs @@ -0,0 +1,11 @@ +mod accumulator; +mod pipeline; +mod selection; +mod sort_document; +mod stage; + +pub use self::accumulator::Accumulator; +pub use self::pipeline::Pipeline; +pub use self::selection::Selection; +pub use self::sort_document::SortDocument; +pub use self::stage::Stage; diff --git a/crates/mongodb-agent-common/src/mongodb/pipeline.rs b/crates/mongodb-support/src/aggregate/pipeline.rs similarity index 73% rename from crates/mongodb-agent-common/src/mongodb/pipeline.rs rename to crates/mongodb-support/src/aggregate/pipeline.rs index 3b728477..0faae2ff 100644 --- a/crates/mongodb-agent-common/src/mongodb/pipeline.rs +++ b/crates/mongodb-support/src/aggregate/pipeline.rs @@ -1,10 +1,12 @@ +use std::{borrow::Borrow, ops::Deref}; + use mongodb::bson; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use super::stage::Stage; /// Aggregation Pipeline -#[derive(Clone, Debug, PartialEq, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(transparent)] pub struct Pipeline { pub stages: Vec<Stage>, } @@ -32,6 +34,26 @@ impl Pipeline { } } +impl AsRef<[Stage]> for Pipeline { + fn as_ref(&self) -> &[Stage] { + &self.stages + } +} + +impl Borrow<[Stage]> for Pipeline { + fn borrow(&self) -> &[Stage] { + &self.stages + } +} + +impl Deref for Pipeline { + type Target = [Stage]; + + fn deref(&self) -> &Self::Target { + &self.stages + } +} + /// This impl allows passing a [Pipeline] as the first argument to [mongodb::Collection::aggregate].
impl IntoIterator for Pipeline { type Item = bson::Document; @@ -57,3 +79,9 @@ impl FromIterator<Stage> for Pipeline { } } } + +impl From<Pipeline> for Vec<bson::Document> { + fn from(value: Pipeline) -> Self { + value.into_iter().collect() + } +} diff --git a/crates/mongodb-support/src/aggregate/selection.rs b/crates/mongodb-support/src/aggregate/selection.rs new file mode 100644 index 00000000..faa04b0d --- /dev/null +++ b/crates/mongodb-support/src/aggregate/selection.rs @@ -0,0 +1,57 @@ +use mongodb::bson::{self, Bson}; +use serde::{Deserialize, Serialize}; + +/// Wraps a BSON document that represents a MongoDB "expression" that constructs a document based +/// on the output of a previous aggregation pipeline stage. A Selection value is intended to be +/// used as the argument to a $replaceWith pipeline stage. +/// +/// When we compose pipelines, we can pair each Pipeline with a Selection that extracts the data we +/// want, in the format we want it to provide to HGE. We can collect Selection values and merge +/// them to form one stage after all of the composed pipelines. +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] +#[serde(transparent)] +pub struct Selection(bson::Document); + +impl Selection { + pub fn new(doc: bson::Document) -> Self { + Self(doc) + } + + /// Transform the contained BSON document in a callback. This may return an error on invariant + /// violations in the future. + pub fn try_map_document<F>(self, callback: F) -> Result<Self, anyhow::Error> + where + F: FnOnce(bson::Document) -> bson::Document, + { + let doc = self.into(); + let updated_doc = callback(doc); + Ok(Self::new(updated_doc)) + } +} + +/// The extend implementation provides a shallow merge. +impl Extend<(String, Bson)> for Selection { + fn extend<T: IntoIterator<Item = (String, Bson)>>(&mut self, iter: T) { + self.0.extend(iter); + } +} + +impl From<Selection> for bson::Document { + fn from(value: Selection) -> Self { + value.0 + } +} + +impl<'a> From<&'a Selection> for &'a bson::Document { + fn from(value: &'a Selection) -> Self { + &value.0 + } +} + +// This won't fail, but it might in the future if we add some sort of validation or parsing. +impl TryFrom<bson::Document> for Selection { + type Error = anyhow::Error; + fn try_from(value: bson::Document) -> Result<Self, Self::Error> { + Ok(Selection(value)) + } +} diff --git a/crates/mongodb-agent-common/src/mongodb/sort_document.rs b/crates/mongodb-support/src/aggregate/sort_document.rs similarity index 100% rename from crates/mongodb-agent-common/src/mongodb/sort_document.rs rename to crates/mongodb-support/src/aggregate/sort_document.rs diff --git a/crates/mongodb-agent-common/src/mongodb/stage.rs b/crates/mongodb-support/src/aggregate/stage.rs similarity index 85% rename from crates/mongodb-agent-common/src/mongodb/stage.rs rename to crates/mongodb-support/src/aggregate/stage.rs index 87dc51bb..a604fd40 100644 --- a/crates/mongodb-agent-common/src/mongodb/stage.rs +++ b/crates/mongodb-support/src/aggregate/stage.rs @@ -1,15 +1,15 @@ use std::collections::BTreeMap; use mongodb::bson; -use serde::Serialize; +use serde::{Deserialize, Serialize}; -use super::{accumulator::Accumulator, pipeline::Pipeline, Selection, SortDocument}; +use super::{Accumulator, Pipeline, Selection, SortDocument}; /// Aggregation Pipeline Stage. This is a work-in-progress - we are adding enum variants to match /// MongoDB pipeline stage types as we need them in this app.
For documentation on all stage types /// see, /// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#std-label-aggregation-pipeline-operator-reference -#[derive(Clone, Debug, PartialEq, Serialize)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] pub enum Stage { /// Adds new fields to documents. $addFields outputs documents that contain all existing fields /// from the input documents and newly added fields. @@ -156,6 +156,32 @@ pub enum Stage { #[serde(rename = "$replaceWith")] ReplaceWith(Selection), + /// Deconstructs an array field from the input documents to output a document for each element. + /// Each output document is the input document with the value of the array field replaced by + /// the element. + /// + /// See https://www.mongodb.com/docs/manual/reference/operator/aggregation/unwind/ + #[serde(rename = "$unwind", rename_all = "camelCase")] + Unwind { + /// Field path to an array field. To specify a field path, prefix the field name with + /// a dollar sign $ and enclose in quotes. + path: String, + + /// Optional. The name of a new field to hold the array index of the element. The name + /// cannot start with a dollar sign $. + #[serde(default, skip_serializing_if = "Option::is_none")] + include_array_index: Option<String>, + + /// Optional. + /// + /// - If true, if the path is null, missing, or an empty array, $unwind outputs the document. + /// - If false, if path is null, missing, or an empty array, $unwind does not output a document. + /// + /// The default value is false. + #[serde(default, skip_serializing_if = "Option::is_none")] + preserve_null_and_empty_arrays: Option<bool>, + }, + /// For cases where we receive pipeline stages from an external source, such as a native query, /// and we don't want to attempt to parse it we store the stage BSON document unaltered. #[serde(untagged)] diff --git a/crates/mongodb-support/src/bson_type.rs b/crates/mongodb-support/src/bson_type.rs index 5024a2cf..dd1e63ef 100644 --- a/crates/mongodb-support/src/bson_type.rs +++ b/crates/mongodb-support/src/bson_type.rs @@ -80,7 +80,7 @@ impl<'de> Deserialize<'de> for BsonType { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Sequence, Serialize, Deserialize, JsonSchema)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Sequence, Serialize, Deserialize, JsonSchema)] #[serde(try_from = "BsonType", rename_all = "camelCase")] pub enum BsonScalarType { // numeric diff --git a/crates/mongodb-support/src/lib.rs b/crates/mongodb-support/src/lib.rs index 2f45f8de..f8113b81 100644 --- a/crates/mongodb-support/src/lib.rs +++ b/crates/mongodb-support/src/lib.rs @@ -1,3 +1,4 @@ +pub mod aggregate; pub mod align; mod bson_type; pub mod error; diff --git a/crates/test-helpers/src/configuration.rs b/crates/test-helpers/src/configuration.rs new file mode 100644 index 00000000..d125fc6a --- /dev/null +++ b/crates/test-helpers/src/configuration.rs @@ -0,0 +1,38 @@ +use configuration::Configuration; +use ndc_test_helpers::{collection, named_type, object_type}; + +/// Configuration for a MongoDB database that resembles MongoDB's sample_mflix test data set.
+pub fn mflix_config() -> Configuration { + Configuration { + collections: [collection("comments"), collection("movies")].into(), + object_types: [ + ( + "comments".into(), + object_type([ + ("_id", named_type("ObjectId")), + ("movie_id", named_type("ObjectId")), + ("name", named_type("String")), + ]), + ), + ( + "credits".into(), + object_type([("director", named_type("String"))]), + ), + ( + "movies".into(), + object_type([ + ("_id", named_type("ObjectId")), + ("credits", named_type("credits")), + ("title", named_type("String")), + ("year", named_type("Int")), + ]), + ), + ] + .into(), + functions: Default::default(), + procedures: Default::default(), + native_mutations: Default::default(), + native_queries: Default::default(), + options: Default::default(), + } +} diff --git a/crates/test-helpers/src/lib.rs b/crates/test-helpers/src/lib.rs index e9ac03ea..d77f5c81 100644 --- a/crates/test-helpers/src/lib.rs +++ b/crates/test-helpers/src/lib.rs @@ -1,6 +1,7 @@ pub mod arb_bson; mod arb_plan_type; pub mod arb_type; +pub mod configuration; use enum_iterator::Sequence as _; use mongodb_support::ExtendedJsonMode; diff --git a/fixtures/hasura/README.md b/fixtures/hasura/README.md index 45f5b3f8..cb31e000 100644 --- a/fixtures/hasura/README.md +++ b/fixtures/hasura/README.md @@ -32,11 +32,11 @@ this repo. The plugin binary is provided by the Nix dev shell. Use these commands: ```sh -$ mongodb-cli-plugin --connection-uri mongodb://localhost/sample_mflix --context-path sample_mflix/connector/ update +$ nix run .#mongodb-cli-plugin -- --connection-uri mongodb://localhost/sample_mflix --context-path sample_mflix/connector/ update -$ mongodb-cli-plugin --connection-uri mongodb://localhost/chinook --context-path chinook/connector/ update +$ nix run .#mongodb-cli-plugin -- --connection-uri mongodb://localhost/chinook --context-path chinook/connector/ update -$ mongodb-cli-plugin --connection-uri mongodb://localhost/test_cases --context-path test_cases/connector/ update +$ nix run .#mongodb-cli-plugin -- --connection-uri mongodb://localhost/test_cases --context-path test_cases/connector/ update ``` Update Hasura metadata based on connector configuration diff --git a/flake.nix b/flake.nix index f0056bc3..b5c2756b 100644 --- a/flake.nix +++ b/flake.nix @@ -210,7 +210,6 @@ ddn just mongosh - mongodb-cli-plugin pkg-config ]; };
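Editor's note: the serde attributes on the new `Stage::Unwind` variant earlier in this diff determine its wire format: the variant name maps to the `$unwind` key, fields are renamed to camelCase, and `None` options are omitted entirely. A minimal self-contained sketch follows, using a pared-down stand-in enum rather than the full `mongodb_support::aggregate::Stage`.

```rust
use serde::{Deserialize, Serialize};

// Pared-down stand-in for mongodb_support::aggregate::Stage that keeps only
// the $unwind variant, to show how the serde attributes shape the output.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
enum Stage {
    #[serde(rename = "$unwind", rename_all = "camelCase")]
    Unwind {
        path: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        include_array_index: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        preserve_null_and_empty_arrays: Option<bool>,
    },
}

fn main() -> Result<(), serde_json::Error> {
    let stage = Stage::Unwind {
        path: "$items".to_string(),
        include_array_index: None,
        preserve_null_and_empty_arrays: Some(true),
    };
    let json = serde_json::to_string(&stage)?;
    // include_array_index is None, so it is skipped; the remaining fields are
    // camelCased and nested under the externally tagged "$unwind" key.
    assert_eq!(
        json,
        r#"{"$unwind":{"path":"$items","preserveNullAndEmptyArrays":true}}"#
    );
    Ok(())
}
```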