aboutsummaryrefslogtreecommitdiffstats
path: root/include/crow/websocket.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/crow/websocket.h')
-rw-r--r--include/crow/websocket.h506
1 files changed, 506 insertions, 0 deletions
diff --git a/include/crow/websocket.h b/include/crow/websocket.h
new file mode 100644
index 0000000..d21d7e9
--- /dev/null
+++ b/include/crow/websocket.h
@@ -0,0 +1,506 @@
+#pragma once
+#include <boost/algorithm/string/predicate.hpp>
+#include <boost/array.hpp>
+#include "crow/socket_adaptors.h"
+#include "crow/http_request.h"
+#include "crow/TinySHA1.hpp"
+
+namespace crow
+{
+ namespace websocket
+ {
+ enum class WebSocketReadState
+ {
+ MiniHeader,
+ Len16,
+ Len64,
+ Mask,
+ Payload,
+ };
+
+ struct connection
+ {
+ virtual void send_binary(const std::string& msg) = 0;
+ virtual void send_text(const std::string& msg) = 0;
+ virtual void close(const std::string& msg = "quit") = 0;
+ virtual ~connection(){}
+
+ void userdata(void* u) { userdata_ = u; }
+ void* userdata() { return userdata_; }
+
+ private:
+ void* userdata_;
+ };
+
+ template <typename Adaptor>
+ class Connection : public connection
+ {
+ public:
+ Connection(const crow::request& req, Adaptor&& adaptor,
+ std::function<void(crow::websocket::connection&)> open_handler,
+ std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
+ std::function<void(crow::websocket::connection&, const std::string&)> close_handler,
+ std::function<void(crow::websocket::connection&)> error_handler,
+ std::function<bool(const crow::request&)> accept_handler)
+ : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler))
+ , accept_handler_(std::move(accept_handler))
+ {
+ if (!boost::iequals(req.get_header_value("upgrade"), "websocket"))
+ {
+ adaptor.close();
+ delete this;
+ return;
+ }
+
+ if (accept_handler_)
+ {
+ if (!accept_handler_(req))
+ {
+ adaptor.close();
+ delete this;
+ return;
+ }
+ }
+
+ // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+ // Sec-WebSocket-Version: 13
+ std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+ sha1::SHA1 s;
+ s.processBytes(magic.data(), magic.size());
+ uint8_t digest[20];
+ s.getDigestBytes(digest);
+ start(crow::utility::base64encode((char*)digest, 20));
+ }
+
+ template<typename CompletionHandler>
+ void dispatch(CompletionHandler handler)
+ {
+ adaptor_.get_io_service().dispatch(handler);
+ }
+
+ template<typename CompletionHandler>
+ void post(CompletionHandler handler)
+ {
+ adaptor_.get_io_service().post(handler);
+ }
+
+ void send_pong(const std::string& msg)
+ {
+ dispatch([this, msg]{
+ char buf[3] = "\x8A\x00";
+ buf[1] += msg.size();
+ write_buffers_.emplace_back(buf, buf+2);
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ void send_binary(const std::string& msg) override
+ {
+ dispatch([this, msg]{
+ auto header = build_header(2, msg.size());
+ write_buffers_.emplace_back(std::move(header));
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ void send_text(const std::string& msg) override
+ {
+ dispatch([this, msg]{
+ auto header = build_header(1, msg.size());
+ write_buffers_.emplace_back(std::move(header));
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ void close(const std::string& msg) override
+ {
+ dispatch([this, msg]{
+ has_sent_close_ = true;
+ if (has_recv_close_ && !is_close_handler_called_)
+ {
+ is_close_handler_called_ = true;
+ if (close_handler_)
+ close_handler_(*this, msg);
+ }
+ auto header = build_header(0x8, msg.size());
+ write_buffers_.emplace_back(std::move(header));
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ protected:
+
+ std::string build_header(int opcode, size_t size)
+ {
+ char buf[2+8] = "\x80\x00";
+ buf[0] += opcode;
+ if (size < 126)
+ {
+ buf[1] += size;
+ return {buf, buf+2};
+ }
+ else if (size < 0x10000)
+ {
+ buf[1] += 126;
+ *(uint16_t*)(buf+2) = htons((uint16_t)size);
+ return {buf, buf+4};
+ }
+ else
+ {
+ buf[1] += 127;
+ *(uint64_t*)(buf+2) = ((1==htonl(1)) ? (uint64_t)size : ((uint64_t)htonl((size) & 0xFFFFFFFF) << 32) | htonl((size) >> 32));
+ return {buf, buf+10};
+ }
+ }
+
+ void start(std::string&& hello)
+ {
+ static std::string header = "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ";
+ static std::string crlf = "\r\n";
+ write_buffers_.emplace_back(header);
+ write_buffers_.emplace_back(std::move(hello));
+ write_buffers_.emplace_back(crlf);
+ write_buffers_.emplace_back(crlf);
+ do_write();
+ if (open_handler_)
+ open_handler_(*this);
+ do_read();
+ }
+
+ void do_read()
+ {
+ is_reading = true;
+ switch(state_)
+ {
+ case WebSocketReadState::MiniHeader:
+ {
+ //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1),
+ adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+ mini_header_ = htons(mini_header_);
+#ifdef CROW_ENABLE_DEBUG
+
+ if (!ec && bytes_transferred != 2)
+ {
+ throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec && ((mini_header_ & 0x80) == 0x80))
+ {
+ if ((mini_header_ & 0x7f) == 127)
+ {
+ state_ = WebSocketReadState::Len64;
+ }
+ else if ((mini_header_ & 0x7f) == 126)
+ {
+ state_ = WebSocketReadState::Len16;
+ }
+ else
+ {
+ remaining_length_ = mini_header_ & 0x7f;
+ state_ = WebSocketReadState::Mask;
+ }
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ adaptor_.close();
+ if (error_handler_)
+ error_handler_(*this);
+ check_destroy();
+ }
+ });
+ }
+ break;
+ case WebSocketReadState::Len16:
+ {
+ remaining_length_ = 0;
+ uint16_t remaining_length16_ = 0;
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length16_, 2),
+ [this,&remaining_length16_](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+ remaining_length16_ = ntohs(remaining_length16_);
+ remaining_length_ = remaining_length16_;
+#ifdef CROW_ENABLE_DEBUG
+ if (!ec && bytes_transferred != 2)
+ {
+ throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec)
+ {
+ state_ = WebSocketReadState::Mask;
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ adaptor_.close();
+ if (error_handler_)
+ error_handler_(*this);
+ check_destroy();
+ }
+ });
+ }
+ break;
+ case WebSocketReadState::Len64:
+ {
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+ remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32));
+#ifdef CROW_ENABLE_DEBUG
+ if (!ec && bytes_transferred != 8)
+ {
+ throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec)
+ {
+ state_ = WebSocketReadState::Mask;
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ adaptor_.close();
+ if (error_handler_)
+ error_handler_(*this);
+ check_destroy();
+ }
+ });
+ }
+ break;
+ case WebSocketReadState::Mask:
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+#ifdef CROW_ENABLE_DEBUG
+ if (!ec && bytes_transferred != 4)
+ {
+ throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec)
+ {
+ state_ = WebSocketReadState::Payload;
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ if (error_handler_)
+ error_handler_(*this);
+ adaptor_.close();
+ }
+ });
+ break;
+ case WebSocketReadState::Payload:
+ {
+ size_t to_read = buffer_.size();
+ if (remaining_length_ < to_read)
+ to_read = remaining_length_;
+ adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+
+ if (!ec)
+ {
+ fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred);
+ remaining_length_ -= bytes_transferred;
+ if (remaining_length_ == 0)
+ {
+ handle_fragment();
+ state_ = WebSocketReadState::MiniHeader;
+ do_read();
+ }
+ }
+ else
+ {
+ close_connection_ = true;
+ if (error_handler_)
+ error_handler_(*this);
+ adaptor_.close();
+ }
+ });
+ }
+ break;
+ }
+ }
+
+ bool is_FIN()
+ {
+ return mini_header_ & 0x8000;
+ }
+
+ int opcode()
+ {
+ return (mini_header_ & 0x0f00) >> 8;
+ }
+
+ void handle_fragment()
+ {
+ for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++)
+ {
+ fragment_[i] ^= ((char*)&mask_)[i%4];
+ }
+ switch(opcode())
+ {
+ case 0: // Continuation
+ {
+ message_ += fragment_;
+ if (is_FIN())
+ {
+ if (message_handler_)
+ message_handler_(*this, message_, is_binary_);
+ message_.clear();
+ }
+ }
+ case 1: // Text
+ {
+ is_binary_ = false;
+ message_ += fragment_;
+ if (is_FIN())
+ {
+ if (message_handler_)
+ message_handler_(*this, message_, is_binary_);
+ message_.clear();
+ }
+ }
+ break;
+ case 2: // Binary
+ {
+ is_binary_ = true;
+ message_ += fragment_;
+ if (is_FIN())
+ {
+ if (message_handler_)
+ message_handler_(*this, message_, is_binary_);
+ message_.clear();
+ }
+ }
+ break;
+ case 0x8: // Close
+ {
+ has_recv_close_ = true;
+ if (!has_sent_close_)
+ {
+ close(fragment_);
+ }
+ else
+ {
+ adaptor_.close();
+ close_connection_ = true;
+ if (!is_close_handler_called_)
+ {
+ if (close_handler_)
+ close_handler_(*this, fragment_);
+ is_close_handler_called_ = true;
+ }
+ check_destroy();
+ }
+ }
+ break;
+ case 0x9: // Ping
+ {
+ send_pong(fragment_);
+ }
+ break;
+ case 0xA: // Pong
+ {
+ pong_received_ = true;
+ }
+ break;
+ }
+
+ fragment_.clear();
+ }
+
+ void do_write()
+ {
+ if (sending_buffers_.empty())
+ {
+ sending_buffers_.swap(write_buffers_);
+ std::vector<boost::asio::const_buffer> buffers;
+ buffers.reserve(sending_buffers_.size());
+ for(auto& s:sending_buffers_)
+ {
+ buffers.emplace_back(boost::asio::buffer(s));
+ }
+ boost::asio::async_write(adaptor_.socket(), buffers,
+ [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/)
+ {
+ sending_buffers_.clear();
+ if (!ec && !close_connection_)
+ {
+ if (!write_buffers_.empty())
+ do_write();
+ if (has_sent_close_)
+ close_connection_ = true;
+ }
+ else
+ {
+ close_connection_ = true;
+ check_destroy();
+ }
+ });
+ }
+ }
+
+ void check_destroy()
+ {
+ //if (has_sent_close_ && has_recv_close_)
+ if (!is_close_handler_called_)
+ if (close_handler_)
+ close_handler_(*this, "uncleanly");
+ if (sending_buffers_.empty() && !is_reading)
+ delete this;
+ }
+ private:
+ Adaptor adaptor_;
+
+ std::vector<std::string> sending_buffers_;
+ std::vector<std::string> write_buffers_;
+
+ boost::array<char, 4096> buffer_;
+ bool is_binary_;
+ std::string message_;
+ std::string fragment_;
+ WebSocketReadState state_{WebSocketReadState::MiniHeader};
+ uint64_t remaining_length_{0};
+ bool close_connection_{false};
+ bool is_reading{false};
+ uint32_t mask_;
+ uint16_t mini_header_;
+ bool has_sent_close_{false};
+ bool has_recv_close_{false};
+ bool error_occured_{false};
+ bool pong_received_{false};
+ bool is_close_handler_called_{false};
+
+ std::function<void(crow::websocket::connection&)> open_handler_;
+ std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
+ std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
+ std::function<void(crow::websocket::connection&)> error_handler_;
+ std::function<bool(const crow::request&)> accept_handler_;
+ };
+ }
+}