diff --git a/muxrpc/handlers/tunnel/server/connect.go b/muxrpc/handlers/tunnel/server/connect.go index 0aed687..8cc4b44 100644 --- a/muxrpc/handlers/tunnel/server/connect.go +++ b/muxrpc/handlers/tunnel/server/connect.go @@ -13,6 +13,7 @@ import ( "go.cryptoscope.co/muxrpc/v2" "github.com/ssb-ngi-pointer/go-ssb-room/internal/network" + "github.com/ssb-ngi-pointer/go-ssb-room/roomstate" refs "go.mindeco.de/ssb-refs" ) @@ -26,7 +27,28 @@ type connectWithOriginArg struct { Origin refs.FeedRef `json:"origin"` // who started the call } -func (h *Handler) connect(ctx context.Context, req *muxrpc.Request, peerSrc *muxrpc.ByteSource, peerSnk *muxrpc.ByteSink) error { +type connectHandler struct { + logger kitlog.Logger + self refs.FeedRef + + state *roomstate.Manager +} + +// HandleConnect for tunnel.connect makes sure peers whos muxrpc session ends are removed from the room state +func (h connectHandler) HandleConnect(ctx context.Context, edp muxrpc.Endpoint) { + // block until the channel is closed when the rpc session ends + <-ctx.Done() + + peer, err := network.GetFeedRefFromAddr(edp.Remote()) + if err != nil { + return + } + + h.state.Remove(*peer) +} + +// HandleDuplex here implements the tunnel.connect behavior of the server-side. It receives incoming events +func (h connectHandler) HandleDuplex(ctx context.Context, req *muxrpc.Request, peerSrc *muxrpc.ByteSource, peerSnk *muxrpc.ByteSink) error { // unpack arguments var args []connectArg err := json.Unmarshal(req.RawArgs, &args) diff --git a/muxrpc/handlers/tunnel/server/plugin.go b/muxrpc/handlers/tunnel/server/plugin.go index 6a05df7..80499bf 100644 --- a/muxrpc/handlers/tunnel/server/plugin.go +++ b/muxrpc/handlers/tunnel/server/plugin.go @@ -43,5 +43,9 @@ func (h *Handler) Register(mux typemux.HandlerMux, namespace muxrpc.Method) { mux.RegisterSource(append(namespace, "endpoints"), typemux.SourceFunc(h.endpoints)) - mux.RegisterDuplex(append(namespace, "connect"), typemux.DuplexFunc(h.connect)) + mux.RegisterDuplex(append(namespace, "connect"), connectHandler{ + logger: h.logger, + self: h.self, + state: h.state, + }) }