Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ push-based execution #2061

Open
joocer opened this issue Oct 12, 2024 · 1 comment
Open

✨ push-based execution #2061

joocer opened this issue Oct 12, 2024 · 1 comment

Comments

@joocer
Copy link
Contributor

joocer commented Oct 12, 2024

As a stepping stone for parallel execution, refactor operators to support push-based execution.

@joocer
Copy link
Contributor Author

joocer commented Oct 16, 2024

I have been experimenting with writing a push-based version of the execution engine:

import time
from typing import Iterator
from typing import Optional

import opteryx
from opteryx.compiled.structures.hash_table import hash_join_map
from opteryx.third_party.travers import Graph
from opteryx.utils.arrow import align_tables
from orso import DataFrame
from orso.tools import random_string

# End-of-stream sentinel: pump nodes yield it after their data, and operators
# check incoming morsels against it to detect that a stream has finished.
EOS = object()  

class Node:
    """Base operator for the push-based execution prototype.

    Calling a node pushes one morsel through :meth:`execute` and updates a
    small set of counters that :meth:`sensors` exposes. Subclasses override
    :meth:`execute` to implement the actual operator.
    """

    def __init__(self, node_type: str):
        self.node_type = node_type
        self.execution_time = 0  # nanoseconds accumulated inside execute()
        self.calls = 0
        self.records_in = 0  # morsels received (None and EOS are not counted)
        self.records_out = 0  # morsels returned (None and EOS are not counted)
        self.identity = random_string()

    def __call__(self, morsel: Optional[DataFrame]) -> Optional[DataFrame]:
        """Run :meth:`execute` on ``morsel`` and record timing and counters.

        Returns whatever :meth:`execute` returns, unchanged.
        """
        # EOS is a sentinel object, so it is matched by identity. Using `==`
        # or `!=` would dispatch to whatever comparison the morsel type
        # defines, which need not return a plain bool.
        if morsel is not None and morsel is not EOS:
            self.records_in += 1
        # perf_counter_ns is monotonic; the wall clock (time_ns) can jump when
        # the system time is adjusted and would corrupt the accumulated total.
        start = time.perf_counter_ns()
        result = self.execute(morsel)
        self.execution_time += time.perf_counter_ns() - start
        self.calls += 1
        if result is not None and result is not EOS:
            self.records_out += 1
        return result

    def execute(self, morsel: Optional[DataFrame]) -> Optional[DataFrame]:
        """Default behaviour: pass EOS through, otherwise add ``node_type``.

        The non-EOS branch applies ``+`` between the morsel and the node's
        name; subclasses replace this method with real operator logic.
        """
        if morsel is EOS:
            return morsel
        return morsel + self.node_type

    def __str__(self):
        return f"{self.node_type} ({self.sensors()})"

    def sensors(self):
        """Return the node's counters as a dictionary."""
        return {
            "calls": self.calls,
            "execution_time": self.execution_time,
            "records_in": self.records_in,
            "records_out": self.records_out,
        }

class PumpNode(Node):
    """Source operator: runs a query and pushes its result, then EOS."""

    def __init__(self, node_type: str, data):
        super().__init__(node_type)
        # Name of the relation to read. It is interpolated verbatim into the
        # SQL text below, so it must come from trusted code, not user input.
        self.data = data

    def __call__(self, morsel: Optional[DataFrame]) -> Iterator[Optional[DataFrame]]:
        """Yield the query result for ``self.data`` followed by ``EOS``.

        ``morsel`` is ignored; it is accepted only so pumps can be invoked
        like any other node. Because this is a generator, the counters are
        updated when iteration starts, not when the node is called.
        """
        self.calls += 1

        # Time the query itself. Previously the timer was started and stopped
        # back-to-back, so execution_time never included the read.
        start = time.perf_counter_ns()
        result = opteryx.query(f"SELECT * FROM {self.data}")
        self.execution_time += time.perf_counter_ns() - start

        self.records_out += 1
        yield result
        yield EOS

class GreedyNode(Node):
    """Blocking operator: buffers every morsel and aggregates once EOS arrives."""

    def __init__(self, node_type: str):
        super().__init__(node_type)
        # Accumulated input; stays None until the first morsel is received.
        self.collector: Optional[DataFrame] = None

    def execute(self, morsel: Optional[DataFrame]) -> Optional[DataFrame]:
        """Collect ``morsel``; on EOS return the grouped result.

        Returns ``None`` while collecting. On EOS it returns the collected
        data grouped by ``id`` with the maximum of ``name``, or ``EOS`` itself
        when no morsel was ever received.
        """
        # EOS is a sentinel object, so it is matched by identity.
        if morsel is EOS:
            if self.collector is None:
                # Nothing to aggregate; forward the end-of-stream marker
                # instead of failing on a missing collector.
                return EOS
            return self.collector.group_by(["id"]).max(["name"])
        if self.collector is None:
            self.collector = morsel
        else:
            # NOTE(review): this appends one list holding all of the morsel's
            # rows; confirm that DataFrame.append accepts a batch of rows
            # rather than a single row.
            self.collector.append([m for m in morsel])
        return None  # Nothing to emit until EOS is received

class FilterNode(Node):
    """Streaming operator that keeps only the rows matching a fixed predicate."""

    def __init__(self, node_type: str):
        super().__init__(node_type)

    def execute(self, morsel: Optional[DataFrame]) -> Optional[DataFrame]:
        """Return ``morsel`` filtered by the predicate, or pass EOS through."""
        # EOS is a sentinel object, so it is matched by identity; `==` would
        # dispatch to whatever equality the morsel type defines.
        if morsel is EOS:
            return EOS
        # Keeps the rows whose first element is even.
        return morsel.query(lambda x: x[0] % 2 == 0)

class JoinNode(Node):
    """Inner join on ``id`` that consumes the left stream first, then the right.

    The node starts in the ``'left'`` state and buffers the left input. When
    the left stream's EOS arrives it builds a hash map over the buffered data
    and switches to ``'right'``; every later morsel is probed against that map.
    """

    def __init__(self, node_type: str):
        super().__init__(node_type)
        # Holds the left morsel while streaming, then the hash map built from
        # it once the left stream has ended.
        self.left_buffer = None
        # Arrow table of the left side, kept so joined rows can be assembled.
        self.left_relation = None
        self.stream = 'left'

    def execute(self, morsel: Optional[DataFrame]) -> Optional[DataFrame]:
        """Buffer left input, or join a right morsel against the left side.

        Returns ``None`` while the left stream is being consumed, ``EOS`` when
        the right stream ends, and otherwise the joined rows for ``morsel``.
        """
        if self.stream == 'left':
            # EOS is a sentinel object, so it is matched by identity.
            if morsel is EOS:
                # Keep the arrow table as well as the hash map: the table is
                # needed later to materialise the left-hand columns.
                self.left_relation = self.left_buffer.arrow()
                self.left_buffer = hash_join_map(self.left_relation, ["id"])
                self.stream = 'right'
            else:
                # NOTE(review): each left morsel replaces the previous one, so
                # only the last is joined; fine for a single-morsel pump, but
                # multi-morsel inputs would need accumulating here.
                self.left_buffer = morsel
            return None

        if morsel is EOS:
            return EOS

        right_relation = morsel.arrow()
        l_indexes = []
        r_indexes = []
        right_hash_map = hash_join_map(right_relation, ["id"])
        for key, right_rows in right_hash_map.hash_table.items():
            left_rows = self.left_buffer.get(key)
            if left_rows:
                # Emit every left/right pairing for the key so both index
                # lists stay the same length when a key repeats on one side.
                for left_row in left_rows:
                    for right_row in right_rows:
                        l_indexes.append(left_row)
                        r_indexes.append(right_row)

        # Pair the left table with the left indexes; passing the right morsel
        # for both sides would drop the left-hand data from the result.
        return DataFrame.from_arrow(
            align_tables(self.left_relation, right_relation, l_indexes, r_indexes)
        )


import string

et = Graph()

# Operators of the example plan, registered in the order listed.
# "enrich_nutrients" is registered but no edge below connects it.
plan_nodes = (
    ("data_source_a", PumpNode("DataSourcePumpA", "$planets")),
    ("data_source_b", PumpNode("DataSourcePumpB", "$satellites")),
    ("transform_fruits_a", FilterNode("TransformFruitsA")),
    ("transform_fruits_b", FilterNode("TransformFruitsB")),
    ("join_fruits", JoinNode("JoinFruits")),
    ("batch_group", GreedyNode("BatchGroup")),
    ("enrich_nutrients", FilterNode("EnrichNutrients")),
)
for node_id, operator in plan_nodes:
    et.add_node(node_id, operator)

# Data flow of the execution plan, as (source, target) pairs: two filtered
# sources feed the join, and the join feeds the grouping node.
plan_edges = (
    ("data_source_a", "transform_fruits_a"),
    ("data_source_b", "transform_fruits_b"),
    ("transform_fruits_a", "join_fruits"),
    ("transform_fruits_b", "join_fruits"),
    ("join_fruits", "batch_group"),
)
for source, target in plan_edges:
    et.add_edge(source, target)

class SerialExecutionEngine:
    """Single-threaded driver that pushes morsels from the plan's entry points.

    Each entry point is called to obtain its morsels, and every morsel is
    pushed depth-first along the plan's outgoing edges. Results produced by
    nodes that have no outgoing edges are yielded to the caller.
    """

    def __init__(self, plan: Graph):
        self.plan = plan

    def execute(self):
        """Drive every entry point in turn, yielding the plan's outputs."""
        for entry_id in self.plan.get_entry_points():
            source = self.plan[entry_id]
            for morsel in source(None):
                yield from self.process_node(entry_id, morsel)

    def process_node(self, nid, morsel):
        """Push ``morsel`` into node ``nid`` and on to its children.

        Pump nodes are not invoked again here: their morsel is forwarded as
        is. Any other node is called with the morsel, and a ``None`` result
        stops propagation along this path.
        """
        node = self.plan[nid]
        is_pump = isinstance(node, PumpNode)

        if is_pump:
            outgoing = morsel
        else:
            outgoing = node(morsel)
            if outgoing is None:
                return

        downstream = [target for _, target, _ in self.plan.outgoing_edges(nid)]
        for child in downstream:
            yield from self.process_node(child, outgoing)

        # Only non-pump nodes without children surface their result.
        if not is_pump and not downstream:
            yield outgoing

se = SerialExecutionEngine(et)

# Drain the plan, keeping only the last morsel it emits. Binding the name
# before the loop means a plan that yields nothing prints None instead of
# raising NameError on an unbound loop variable.
final_morsel = None
for final_morsel in se.execute():
    pass

print(final_morsel)

print(et.draw())

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant