Skip to content

Commit

Permalink
wip: change manager interface
Browse files Browse the repository at this point in the history
  • Loading branch information
katallaxie authored Jul 9, 2024
1 parent e700d82 commit 84779e5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 52 deletions.
12 changes: 6 additions & 6 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ func (c *exampleController) Get() error {
}

type webSrv struct {
factory sse.SenderFactory
manager sse.Manager
}

func (w *webSrv) Start(ctx context.Context, ready server.ReadyFunc, run server.RunFunc) func() error {
Expand All @@ -550,7 +550,7 @@ func (w *webSrv) Start(ctx context.Context, ready server.ReadyFunc, run server.R
return &exampleController{}
}))

app.Get("/sse", sse.NewSSEHandler(w.factory))
app.Get("/sse", sse.NewSSEHandler(w.manager))

app.Post("/error", htmx.NewHxControllerHandler(func() htmx.Controller {
return &exampleController{}
Expand All @@ -569,15 +569,15 @@ func run(ctx context.Context) error {
log.SetFlags(0)
log.SetOutput(os.Stderr)

broadcast := sse.NewBroadcastManager(5)
manager := sse.NewBroadcastManager(5)

webSrv := &webSrv{
factory: broadcast.CreateSender(),
manager: manager,
}

s, _ := server.WithContext(ctx)

s.Listen(broadcast, true)
s.Listen(manager, true)
s.Listen(webSrv, true)

ticker := time.NewTicker(2 * time.Second)
Expand All @@ -588,7 +588,7 @@ func run(ctx context.Context) error {
case <-ctx.Done():
return
case t := <-ticker.C:
broadcast.Send(sse.NewMessage("demo", fmt.Sprintf("Hello, World! %s", t)))
manager.Send() <- sse.NewMessage("demo", fmt.Sprintf("Hello, World! %s", t))
}
}
}()
Expand Down
114 changes: 68 additions & 46 deletions sse/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ type Event interface {
String() string // Represent the envelope contents as a string for transmission.
}

// Sender is the interface for sending server-sent events.
type Sender interface {
// Client is the interface for sending server-sent events.
type Client interface {
ID() string
Events() chan Event
Close()
}

// SenderFactory is the interface for creating new senders.
type SenderFactory func(c *fiber.Ctx) Sender
// Manager is the interface for broadcasting messages to clients.
type Manager interface {
Add(Client)
Remove(Client)
Send() chan<- Event
}

// EventImpl is the default implementation of the Event interface.
type EventImpl struct {
Expand Down Expand Up @@ -59,8 +63,11 @@ func (e *EventImpl) String() string {
return sb.String()
}

var _ Manager = (*BroadcastManagerImpl)(nil)

// BroadcastManagerImpl is the default implementation of the BroadcastManager interface.
type BroadcastManagerImpl struct {
id string
broadcast chan Event
poolSize int
clients sync.Map
Expand All @@ -71,19 +78,30 @@ var _ server.Listener = (*BroadcastManagerImpl)(nil)
// NewBroadcastManager creates a new broadcast manager.
func NewBroadcastManager(poolSize int) *BroadcastManagerImpl {
return &BroadcastManagerImpl{
id: uuid.NewString(),
broadcast: make(chan Event),
poolSize: poolSize,
}
}

// Add adds a client to the broadcast manager.
func (b *BroadcastManagerImpl) Add(client Client) {
b.clients.Store(client.ID(), client)
}

// Remove removes a client from the broadcast manager.
func (b *BroadcastManagerImpl) Remove(client Client) {
b.clients.Delete(client.ID())
}

// Start starts the broadcast manager.
func (b *BroadcastManagerImpl) Start(ctx context.Context, ready server.ReadyFunc, run server.RunFunc) func() error {
return func() error {
if b.poolSize < 1 {
return server.NewError(fmt.Errorf("pool size must be greater than 0"))
}

b.startWorkers()
b.startWorkers(ctx)

ready()

Expand Down Expand Up @@ -111,7 +129,7 @@ func (c *ClientImpl) Events() chan Event {

// Close closes the client.
func (c *ClientImpl) Close() {
<-c.events
<-c.events // Drain the channel.
}

// NewClient creates a new client.
Expand All @@ -126,48 +144,43 @@ func NewClient(id string) *ClientImpl {
}
}

// CreateSender creates a new sender.
func (b *BroadcastManagerImpl) CreateSender() SenderFactory {
return func(c *fiber.Ctx) Sender {
client := NewClient("")
b.clients.Store(client.ID(), client)

return client
}
}

func (b *BroadcastManagerImpl) startWorkers() {
func (b *BroadcastManagerImpl) startWorkers(ctx context.Context) {
for i := 0; i < b.poolSize; i++ {
go func() {
for message := range b.broadcast {
b.clients.Range(func(key, value any) bool {
client, ok := value.(Sender)
if !ok {
return true
}
for {
select {
case msg := <-b.broadcast:
b.clients.Range(func(key, value any) bool {
client, ok := value.(Client)
if !ok {
return true
}

select {
case client.Events() <- msg:
default: // client not reachable
}

select {
case client.Events() <- message:
default:
}

return true
})
return true
})
case <-ctx.Done():
return
}
}
}()
}
}

// Send sends a message to all clients.
func (b *BroadcastManagerImpl) Send(event Event) {
b.broadcast <- event
func (b *BroadcastManagerImpl) Send() chan<- Event {
return b.broadcast
}

// Config is the configuration for the server-sent events server.
type Config struct{}

// NewSSEHandler creates a new server-sent events handler.
func NewSSEHandler(sender SenderFactory, config ...Config) fiber.Handler {
func NewSSEHandler(manager Manager, config ...Config) fiber.Handler {
_ = configDefault(config...)

return func(c *fiber.Ctx) error {
Expand All @@ -176,24 +189,33 @@ func NewSSEHandler(sender SenderFactory, config ...Config) fiber.Handler {
c.Set(fiber.HeaderConnection, "keep-alive")
c.Set(fiber.HeaderTransferEncoding, "chunked")

s := sender(c) // return a sender
client := NewClient(uuid.NewString())
manager.Add(client)

notify := c.Context().Done()

c.Status(fiber.StatusOK).Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
defer s.Close()
for msg := range s.Events() {
_, err := fmt.Fprint(w, msg.String())
if err != nil {
for {
select {
case <-notify:
manager.Remove(client)
fmt.Println("client removed")
return
}
case msg := <-client.Events():
fmt.Println("msg received")
_, err := fmt.Fprint(w, msg.String())
if err != nil {
return
}

err = w.Flush()
if err != nil {
// Refreshing page in web browser will establish a new
// SSE connection, but only (the last) one is alive, so
// dead connections must be closed here.
return
err = w.Flush()
if err != nil {
// Refreshing page in web browser will establish a new
// SSE connection, but only (the last) one is alive, so
// dead connections must be closed here.
return
}
}

}
}))

Expand Down

0 comments on commit 84779e5

Please sign in to comment.