diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index 18ccfd2a86..d1fafba393 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -106,10 +106,14 @@ fn eval(c: &mut Criterion) { c.bench_function(name, |b| { let tx = raw.db.begin_tx(Workload::Subscribe); let auth = AuthCtx::for_testing(); - let schema_viewer = &SchemaViewer::new(&raw.db, &tx, &auth); + let schema_viewer = &SchemaViewer::new(&tx, &auth); let plan = SubscribePlan::compile(sql, schema_viewer).unwrap(); - b.iter(|| drop(black_box(plan.execute_bsatn(&tx)))) + b.iter(|| { + drop(black_box( + plan.collect_table_update::(Compression::None, &tx), + )) + }) }); }; diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index fb646581a7..1e3cb38f6f 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -12,7 +12,9 @@ use crate::util::prometheus_handle::IntGaugeExt; use crate::worker_metrics::WORKER_METRICS; use derive_more::From; use futures::prelude::*; -use spacetimedb_client_api_messages::websocket::{CallReducerFlags, Compression, FormatSwitch}; +use spacetimedb_client_api_messages::websocket::{ + BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, WebsocketFormat, +}; use spacetimedb_lib::identity::RequestId; use tokio::sync::{mpsc, oneshot, watch}; use tokio::task::AbortHandle; @@ -294,26 +296,41 @@ impl ClientConnection { .unwrap() } - pub fn one_off_query(&self, query: &str, message_id: &[u8], timer: Instant) -> Result<(), anyhow::Error> { - let result = self.module.one_off_query(self.id.identity, query.to_owned()); + pub fn one_off_query_json(&self, query: &str, message_id: &[u8], timer: Instant) -> Result<(), anyhow::Error> { + let response = self.one_off_query::(query, message_id, timer); + self.send_message(response)?; + Ok(()) + } + + pub fn one_off_query_bsatn(&self, query: &str, message_id: &[u8], timer: Instant) -> Result<(), anyhow::Error> { + let response = self.one_off_query::(query, message_id, timer); + self.send_message(response)?; + Ok(()) + } + + fn one_off_query( + &self, + query: &str, + message_id: &[u8], + timer: Instant, + ) -> OneOffQueryResponseMessage { + let result = self.module.one_off_query::(self.id.identity, query.to_owned()); let message_id = message_id.to_owned(); let total_host_execution_duration = timer.elapsed().as_micros() as u64; - let response = match result { + match result { Ok(results) => OneOffQueryResponseMessage { message_id, error: None, - results, + results: vec![results], total_host_execution_duration, }, Err(err) => OneOffQueryResponseMessage { message_id, error: Some(format!("{}", err)), - results: Vec::new(), + results: vec![], total_host_execution_duration, }, - }; - self.send_message(response)?; - Ok(()) + } } pub async fn disconnect(self) { diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index af59b10160..61cec87df6 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -1,5 +1,5 @@ use super::messages::{SubscriptionUpdateMessage, SwitchedServerMessage, ToProtocol, TransactionUpdateMessage}; -use super::{ClientConnection, DataMessage}; +use super::{ClientConnection, DataMessage, Protocol}; use crate::energy::EnergyQuanta; use crate::execution_context::WorkloadType; use crate::host::module_host::{EventStatus, ModuleEvent, ModuleFunctionCall}; @@ -91,7 +91,10 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst query_string: query, message_id, }) => { - let res = client.one_off_query(&query, &message_id, timer); + let res = match client.config.protocol { + Protocol::Binary => client.one_off_query_bsatn(&query, &message_id, timer), + Protocol::Text => client.one_off_query_json(&query, &message_id, timer), + }; WORKER_METRICS .request_round_trip .with_label_values(&WorkloadType::Sql, &address, "") diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index eacc330ebf..26d794adfa 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -5,14 +5,13 @@ use crate::host::ArgsTuple; use crate::messages::websocket as ws; use derive_more::From; use spacetimedb_client_api_messages::websocket::{ - BsatnFormat, Compression, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI, - SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, + BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, + SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, }; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; use spacetimedb_lib::Address; use spacetimedb_sats::bsatn; -use spacetimedb_vm::relation::MemTable; use std::sync::Arc; use std::time::Instant; @@ -63,7 +62,8 @@ pub fn serialize(msg: impl ToProtocol, config: #[derive(Debug, From)] pub enum SerializableMessage { - Query(OneOffQueryResponseMessage), + QueryBinary(OneOffQueryResponseMessage), + QueryText(OneOffQueryResponseMessage), Identity(IdentityTokenMessage), Subscribe(SubscriptionUpdateMessage), TxUpdate(TransactionUpdateMessage), @@ -72,7 +72,8 @@ pub enum SerializableMessage { impl SerializableMessage { pub fn num_rows(&self) -> Option { match self { - Self::Query(msg) => Some(msg.num_rows()), + Self::QueryBinary(msg) => Some(msg.num_rows()), + Self::QueryText(msg) => Some(msg.num_rows()), Self::Subscribe(msg) => Some(msg.num_rows()), Self::TxUpdate(msg) => Some(msg.num_rows()), Self::Identity(_) => None, @@ -81,7 +82,7 @@ impl SerializableMessage { pub fn workload(&self) -> Option { match self { - Self::Query(_) => Some(WorkloadType::Sql), + Self::QueryBinary(_) | Self::QueryText(_) => Some(WorkloadType::Sql), Self::Subscribe(_) => Some(WorkloadType::Subscribe), Self::TxUpdate(_) => Some(WorkloadType::Update), Self::Identity(_) => None, @@ -93,7 +94,8 @@ impl ToProtocol for SerializableMessage { type Encoded = SwitchedServerMessage; fn to_protocol(self, protocol: Protocol) -> Self::Encoded { match self { - SerializableMessage::Query(msg) => msg.to_protocol(protocol), + SerializableMessage::QueryBinary(msg) => msg.to_protocol(protocol), + SerializableMessage::QueryText(msg) => msg.to_protocol(protocol), SerializableMessage::Identity(msg) => msg.to_protocol(protocol), SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol), SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol), @@ -243,42 +245,38 @@ impl ToProtocol for SubscriptionUpdateMessage { } #[derive(Debug)] -pub struct OneOffQueryResponseMessage { +pub struct OneOffQueryResponseMessage { pub message_id: Vec, pub error: Option, - pub results: Vec, + pub results: Vec>, pub total_host_execution_duration: u64, } -impl OneOffQueryResponseMessage { +impl OneOffQueryResponseMessage { fn num_rows(&self) -> usize { - self.results.iter().map(|t| t.data.len()).sum() + self.results.iter().map(|table| table.rows.len()).sum() } } -impl ToProtocol for OneOffQueryResponseMessage { +impl ToProtocol for OneOffQueryResponseMessage { type Encoded = SwitchedServerMessage; - fn to_protocol(self, protocol: Protocol) -> Self::Encoded { - fn convert(msg: OneOffQueryResponseMessage) -> ws::ServerMessage { - let tables = msg - .results - .into_iter() - .map(|table| ws::OneOffTable { - table_name: table.head.table_name.clone(), - rows: F::encode_list(table.data.into_iter()).0, - }) - .collect(); - ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse { - message_id: msg.message_id.into(), - error: msg.error.map(Into::into), - tables, - total_host_execution_duration_micros: msg.total_host_execution_duration, - }) - } + fn to_protocol(self, _: Protocol) -> Self::Encoded { + FormatSwitch::Bsatn(convert(self)) + } +} - match protocol { - Protocol::Text => FormatSwitch::Json(convert(self)), - Protocol::Binary => FormatSwitch::Bsatn(convert(self)), - } +impl ToProtocol for OneOffQueryResponseMessage { + type Encoded = SwitchedServerMessage; + fn to_protocol(self, _: Protocol) -> Self::Encoded { + FormatSwitch::Json(convert(self)) } } + +fn convert(msg: OneOffQueryResponseMessage) -> ws::ServerMessage { + ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse { + message_id: msg.message_id.into(), + error: msg.error.map(Into::into), + tables: msg.results.into_boxed_slice(), + total_host_execution_duration_micros: msg.total_host_execution_duration, + }) +} diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 1a7258d159..6c96e52a09 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -12,7 +12,7 @@ use crate::hash::Hash; use crate::identity::Identity; use crate::messages::control_db::Database; use crate::replica_context::ReplicaContext; -use crate::sql; +use crate::sql::ast::SchemaViewer; use crate::subscription::module_subscription_actor::ModuleSubscriptions; use crate::util::lending_pool::{Closed, LendingPool, LentResource, PoolClosed}; use crate::worker_metrics::WORKER_METRICS; @@ -24,17 +24,18 @@ use indexmap::IndexSet; use itertools::Itertools; use smallvec::SmallVec; use spacetimedb_client_api_messages::timestamp::Timestamp; -use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, WebsocketFormat}; +use spacetimedb_client_api_messages::websocket::{Compression, OneOffTable, QueryUpdate, WebsocketFormat}; use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_lib::identity::{AuthCtx, RequestId}; use spacetimedb_lib::Address; use spacetimedb_primitives::{col_list, TableId}; +use spacetimedb_query::SubscribePlan; use spacetimedb_sats::{algebraic_value, ProductValue}; use spacetimedb_schema::auto_migrate::AutoMigrateError; use spacetimedb_schema::def::deserialize::ReducerArgsDeserializeSeed; use spacetimedb_schema::def::{ModuleDef, ReducerDef}; -use spacetimedb_vm::relation::{MemTable, RelValue}; +use spacetimedb_vm::relation::RelValue; use std::fmt; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; @@ -836,16 +837,24 @@ impl ModuleHost { } #[tracing::instrument(skip_all)] - pub fn one_off_query(&self, caller_identity: Identity, query: String) -> Result, anyhow::Error> { + pub fn one_off_query( + &self, + caller_identity: Identity, + query: String, + ) -> Result, anyhow::Error> { let replica_ctx = self.replica_ctx(); let db = &replica_ctx.relational_db; let auth = AuthCtx::new(replica_ctx.owner_identity, caller_identity); log::debug!("One-off query: {query}"); db.with_read_only(Workload::Sql, |tx| { - let ast = sql::compiler::compile_sql(db, &auth, tx, &query)?; - sql::execute::execute_sql_tx(db, tx, &query, ast, auth)? - .context("One-off queries are not allowed to modify the database") + let schema_viewer = SchemaViewer::new(tx, &auth); + let plan = SubscribePlan::compile(&query, &schema_viewer)?; + plan.execute_with::>(tx, |rows, _| OneOffTable { + table_name: plan.table_name().to_owned().into_boxed_str(), + rows, + }) + .context("One-off queries are not allowed to modify the database") }) } diff --git a/crates/query/src/lib.rs b/crates/query/src/lib.rs index 8943838cd0..7c18a3c553 100644 --- a/crates/query/src/lib.rs +++ b/crates/query/src/lib.rs @@ -19,6 +19,14 @@ pub struct SubscribePlan { } impl SubscribePlan { + pub fn table_id(&self) -> TableId { + self.table_id + } + + pub fn table_name(&self) -> &str { + self.table_name.as_ref() + } + pub fn compile(sql: &str, tx: &impl SchemaView) -> Result { let ast = parse_subscription(sql)?; let sub = type_subscription(ast, tx)?; @@ -40,14 +48,25 @@ impl SubscribePlan { }) } - pub fn execute(&self, comp: Compression, tx: &impl Datastore) -> Result> { + pub fn execute(&self, tx: &impl Datastore) -> Result<(F::List, u64)> { execute_plan(&self.plan, tx, |iter| match iter { PlanIter::Index(iter) => F::encode_list(iter), PlanIter::Table(iter) => F::encode_list(iter), PlanIter::RowId(iter) => F::encode_list(iter), PlanIter::Tuple(iter) => F::encode_list(iter), }) - .map(|(inserts, num_rows)| { + } + + pub fn execute_with(&self, tx: &impl Datastore, f: impl Fn(F::List, u64) -> R) -> Result { + self.execute::(tx).map(|(list, n)| f(list, n)) + } + + pub fn collect_table_update( + &self, + comp: Compression, + tx: &impl Datastore, + ) -> Result> { + self.execute_with::>(tx, |inserts, num_rows| { let deletes = F::List::default(); let qu = QueryUpdate { deletes, inserts }; let update = F::into_query_update(qu, comp); @@ -63,7 +82,7 @@ where { plans .par_iter() - .map(|plan| plan.execute(comp, tx)) + .map(|plan| plan.collect_table_update(comp, tx)) .collect::>() .map(|tables| DatabaseUpdate { tables }) }