diff --git a/include/crow/app.h b/include/crow/app.h index 923d91dd7b..2d464c5622 100644 --- a/include/crow/app.h +++ b/include/crow/app.h @@ -619,11 +619,26 @@ namespace crow void close_websockets() { - std::lock_guard lock{websockets_mutex_}; - for (auto websocket : websockets_) + std::vector> futures; { - CROW_LOG_INFO << "Quitting Websocket: " << websocket; - websocket->close("Websocket Closed"); + std::lock_guard lock{websockets_mutex_}; + futures.reserve(websockets_.size()); + for (auto& websocket : websockets_) + { + CROW_LOG_INFO << "Quitting Websocket: " << websocket; + auto done = std::make_shared>(); + futures.push_back(done->get_future()); + websocket->close("Websocket Closed", websocket::NormalClosure, std::move(done)); + } + } + // Wait for all close frames to be written to the wire + for (auto& f : futures) + { + // Use a timeout to avoid hanging forever if a connection is stuck + if (f.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) + { + CROW_LOG_WARNING << "Timed out waiting for websocket close frame to send"; + } } } diff --git a/include/crow/websocket.h b/include/crow/websocket.h index 7f91c3dccf..14c9a70cb3 100644 --- a/include/crow/websocket.h +++ b/include/crow/websocket.h @@ -69,7 +69,7 @@ namespace crow // NOTE: Already documented in "crow/app.h" virtual void send_text(std::string msg) = 0; virtual void send_ping(std::string msg) = 0; virtual void send_pong(std::string msg) = 0; - virtual void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure) = 0; + virtual void close(const std::string& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure, std::shared_ptr> done = nullptr) = 0; virtual std::string get_remote_ip() = 0; virtual std::string get_subprotocol() const = 0; virtual ~connection() = default; @@ -123,14 +123,14 @@ namespace crow // NOTE: Already documented in "crow/app.h" std::function&, void**)> accept_handler, bool mirror_protocols) { - auto conn = std::shared_ptr(new Connection(std::move(adaptor), + auto conn = std::shared_ptr(new Connection(std::move(adaptor), handler, max_payload, - std::move(open_handler), - std::move(message_handler), + std::move(open_handler), + std::move(message_handler), std::move(close_handler), - std::move(error_handler), + std::move(error_handler), std::move(accept_handler))); - + // Perform handshake validation if (!utility::string_equals(req.get_header_value("upgrade"), "websocket")) { @@ -255,9 +255,9 @@ namespace crow // NOTE: Already documented in "crow/app.h" /// /// Sets a flag to destroy the object once the message is sent. - void close(std::string const& msg, uint16_t status_code) override + void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure, std::shared_ptr> done = nullptr) override { - dispatch([shared_this = this->shared_from_this(), msg, status_code]() mutable { + dispatch([shared_this = this->shared_from_this(), msg, status_code, done]() mutable { shared_this->has_sent_close_ = true; if (shared_this->has_recv_close_ && !shared_this->is_close_handler_called_) { @@ -272,10 +272,61 @@ namespace crow // NOTE: Already documented in "crow/app.h" shared_this->write_buffers_.emplace_back(std::move(header)); shared_this->write_buffers_.emplace_back(std::string(status_buf, 2)); shared_this->write_buffers_.emplace_back(msg); - shared_this->do_write(); + shared_this->do_write_with_completion(std::move(done)); }); } + void do_write_with_completion(std::shared_ptr> done) + { + if (sending_buffers_.empty()) { + if (write_buffers_.empty()) { + if (done) done->set_value(); + return; + } + + sending_buffers_.swap(write_buffers_); + std::vector buffers; + buffers.reserve(sending_buffers_.size()); + for (auto &s: sending_buffers_) + { + buffers.emplace_back(asio::buffer(s)); + } + auto watch = std::weak_ptr{anchor_}; + asio::async_write( + adaptor_.socket(), buffers, + [shared_this = this->shared_from_this(), watch, done](const error_code &ec, std::size_t) { + auto anchor = watch.lock(); + if (anchor == nullptr) { + if (done) done->set_value(); + return; + } + + shared_this->sending_buffers_.clear(); + if (!ec && !shared_this->close_connection_) + { + if (!shared_this->write_buffers_.empty()) + shared_this->do_write(); + } + if (shared_this->has_sent_close_) + shared_this->close_connection_ = true; + + // Signal that the close frame has been written + if (done) done->set_value(); + }); + } + else + { + // Buffers are currently being sent, fall back to normal write + // and signal immediately (close frame will be picked up by the + // in-flight write's completion handler via write_buffers_) + write_buffers_.insert(write_buffers_.end(), + sending_buffers_.begin(), sending_buffers_.end()); + // Actually this case shouldn't happen for close since we just + // populated write_buffers_ above. But just in case: + if (done) done->set_value(); + } + } + std::string get_remote_ip() override { return adaptor_.address(); @@ -286,7 +337,7 @@ namespace crow // NOTE: Already documented in "crow/app.h" max_payload_bytes_ = payload; } - /// Returns the matching client/server subprotocol, empty string if none matched. + /// Returns the matching client/server subprotocol, empty string if none matched. std::string get_subprotocol() const override { return subprotocol_;