Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions include/crow/app.h
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,26 @@ namespace crow

void close_websockets()
{
std::lock_guard<std::mutex> lock{websockets_mutex_};
for (auto websocket : websockets_)
std::vector<std::future<void>> futures;
{
CROW_LOG_INFO << "Quitting Websocket: " << websocket;
websocket->close("Websocket Closed");
std::lock_guard<std::mutex> lock{websockets_mutex_};
futures.reserve(websockets_.size());
for (auto& websocket : websockets_)
{
CROW_LOG_INFO << "Quitting Websocket: " << websocket;
auto done = std::make_shared<std::promise<void>>();
futures.push_back(done->get_future());
websocket->close("Websocket Closed", websocket::NormalClosure, std::move(done));
}
}
// Wait for all close frames to be written to the wire
for (auto& f : futures)
{
// Use a timeout to avoid hanging forever if a connection is stuck
if (f.wait_for(std::chrono::seconds(5)) == std::future_status::timeout)
{
CROW_LOG_WARNING << "Timed out waiting for websocket close frame to send";
}
}
}

Expand Down
71 changes: 61 additions & 10 deletions include/crow/websocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
virtual void send_text(std::string msg) = 0;
virtual void send_ping(std::string msg) = 0;
virtual void send_pong(std::string msg) = 0;
virtual void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure) = 0;
virtual void close(const std::string& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure, std::shared_ptr<std::promise<void>> done = nullptr) = 0;
virtual std::string get_remote_ip() = 0;
virtual std::string get_subprotocol() const = 0;
virtual ~connection() = default;
Expand Down Expand Up @@ -123,14 +123,14 @@ namespace crow // NOTE: Already documented in "crow/app.h"
std::function<void(const crow::request&, std::optional<crow::response>&, void**)> accept_handler,
bool mirror_protocols)
{
auto conn = std::shared_ptr<Connection>(new Connection(std::move(adaptor),
auto conn = std::shared_ptr<Connection>(new Connection(std::move(adaptor),
handler, max_payload,
std::move(open_handler),
std::move(message_handler),
std::move(open_handler),
std::move(message_handler),
std::move(close_handler),
std::move(error_handler),
std::move(error_handler),
std::move(accept_handler)));

// Perform handshake validation
if (!utility::string_equals(req.get_header_value("upgrade"), "websocket"))
{
Expand Down Expand Up @@ -255,9 +255,9 @@ namespace crow // NOTE: Already documented in "crow/app.h"

///
/// Sets a flag to destroy the object once the message is sent.
void close(std::string const& msg, uint16_t status_code) override
void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure, std::shared_ptr<std::promise<void>> done = nullptr) override
{
dispatch([shared_this = this->shared_from_this(), msg, status_code]() mutable {
dispatch([shared_this = this->shared_from_this(), msg, status_code, done]() mutable {
shared_this->has_sent_close_ = true;
if (shared_this->has_recv_close_ && !shared_this->is_close_handler_called_)
{
Expand All @@ -272,10 +272,61 @@ namespace crow // NOTE: Already documented in "crow/app.h"
shared_this->write_buffers_.emplace_back(std::move(header));
shared_this->write_buffers_.emplace_back(std::string(status_buf, 2));
shared_this->write_buffers_.emplace_back(msg);
shared_this->do_write();
shared_this->do_write_with_completion(std::move(done));
});
}

void do_write_with_completion(std::shared_ptr<std::promise<void>> done)
{
if (sending_buffers_.empty()) {
if (write_buffers_.empty()) {
if (done) done->set_value();
return;
}

sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto &s: sending_buffers_)
{
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[shared_this = this->shared_from_this(), watch, done](const error_code &ec, std::size_t) {
auto anchor = watch.lock();
if (anchor == nullptr) {
if (done) done->set_value();
return;
}

shared_this->sending_buffers_.clear();
if (!ec && !shared_this->close_connection_)
{
if (!shared_this->write_buffers_.empty())
shared_this->do_write();
}
if (shared_this->has_sent_close_)
shared_this->close_connection_ = true;

// Signal that the close frame has been written
if (done) done->set_value();
});
}
else
{
// Buffers are currently being sent, fall back to normal write
// and signal immediately (close frame will be picked up by the
// in-flight write's completion handler via write_buffers_)
write_buffers_.insert(write_buffers_.end(),
sending_buffers_.begin(), sending_buffers_.end());
// Actually this case shouldn't happen for close since we just
// populated write_buffers_ above. But just in case:
if (done) done->set_value();
}
}

std::string get_remote_ip() override
{
return adaptor_.address();
Expand All @@ -286,7 +337,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
max_payload_bytes_ = payload;
}

/// Returns the matching client/server subprotocol, empty string if none matched.
/// Returns the matching client/server subprotocol, empty string if none matched.
std::string get_subprotocol() const override
{
return subprotocol_;
Expand Down
Loading