diff options
Diffstat (limited to 'include/crow/websocket.h')
-rw-r--r-- | include/crow/websocket.h | 506 |
1 files changed, 506 insertions, 0 deletions
diff --git a/include/crow/websocket.h b/include/crow/websocket.h new file mode 100644 index 0000000..d21d7e9 --- /dev/null +++ b/include/crow/websocket.h @@ -0,0 +1,506 @@ +#pragma once +#include <boost/algorithm/string/predicate.hpp> +#include <boost/array.hpp> +#include "crow/socket_adaptors.h" +#include "crow/http_request.h" +#include "crow/TinySHA1.hpp" + +namespace crow +{ + namespace websocket + { + enum class WebSocketReadState + { + MiniHeader, + Len16, + Len64, + Mask, + Payload, + }; + + struct connection + { + virtual void send_binary(const std::string& msg) = 0; + virtual void send_text(const std::string& msg) = 0; + virtual void close(const std::string& msg = "quit") = 0; + virtual ~connection(){} + + void userdata(void* u) { userdata_ = u; } + void* userdata() { return userdata_; } + + private: + void* userdata_; + }; + + template <typename Adaptor> + class Connection : public connection + { + public: + Connection(const crow::request& req, Adaptor&& adaptor, + std::function<void(crow::websocket::connection&)> open_handler, + std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler, + std::function<void(crow::websocket::connection&, const std::string&)> close_handler, + std::function<void(crow::websocket::connection&)> error_handler, + std::function<bool(const crow::request&)> accept_handler) + : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)) + , accept_handler_(std::move(accept_handler)) + { + if (!boost::iequals(req.get_header_value("upgrade"), "websocket")) + { + adaptor.close(); + delete this; + return; + } + + if (accept_handler_) + { + if (!accept_handler_(req)) + { + adaptor.close(); + delete this; + return; + } + } + + // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + // Sec-WebSocket-Version: 13 + std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + sha1::SHA1 s; + s.processBytes(magic.data(), magic.size()); + uint8_t digest[20]; + s.getDigestBytes(digest); + start(crow::utility::base64encode((char*)digest, 20)); + } + + template<typename CompletionHandler> + void dispatch(CompletionHandler handler) + { + adaptor_.get_io_service().dispatch(handler); + } + + template<typename CompletionHandler> + void post(CompletionHandler handler) + { + adaptor_.get_io_service().post(handler); + } + + void send_pong(const std::string& msg) + { + dispatch([this, msg]{ + char buf[3] = "\x8A\x00"; + buf[1] += msg.size(); + write_buffers_.emplace_back(buf, buf+2); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_binary(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(2, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_text(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(1, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void close(const std::string& msg) override + { + dispatch([this, msg]{ + has_sent_close_ = true; + if (has_recv_close_ && !is_close_handler_called_) + { + is_close_handler_called_ = true; + if (close_handler_) + close_handler_(*this, msg); + } + auto header = build_header(0x8, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + protected: + + std::string build_header(int opcode, size_t size) + { + char buf[2+8] = "\x80\x00"; + buf[0] += opcode; + if (size < 126) + { + buf[1] += size; + return {buf, buf+2}; + } + else if (size < 0x10000) + { + buf[1] += 126; + *(uint16_t*)(buf+2) = htons((uint16_t)size); + return {buf, buf+4}; + } + else + { + buf[1] += 127; + *(uint64_t*)(buf+2) = ((1==htonl(1)) ? (uint64_t)size : ((uint64_t)htonl((size) & 0xFFFFFFFF) << 32) | htonl((size) >> 32)); + return {buf, buf+10}; + } + } + + void start(std::string&& hello) + { + static std::string header = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "; + static std::string crlf = "\r\n"; + write_buffers_.emplace_back(header); + write_buffers_.emplace_back(std::move(hello)); + write_buffers_.emplace_back(crlf); + write_buffers_.emplace_back(crlf); + do_write(); + if (open_handler_) + open_handler_(*this); + do_read(); + } + + void do_read() + { + is_reading = true; + switch(state_) + { + case WebSocketReadState::MiniHeader: + { + //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1), + adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + mini_header_ = htons(mini_header_); +#ifdef CROW_ENABLE_DEBUG + + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?"); + } +#endif + + if (!ec && ((mini_header_ & 0x80) == 0x80)) + { + if ((mini_header_ & 0x7f) == 127) + { + state_ = WebSocketReadState::Len64; + } + else if ((mini_header_ & 0x7f) == 126) + { + state_ = WebSocketReadState::Len16; + } + else + { + remaining_length_ = mini_header_ & 0x7f; + state_ = WebSocketReadState::Mask; + } + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len16: + { + remaining_length_ = 0; + uint16_t remaining_length16_ = 0; + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length16_, 2), + [this,&remaining_length16_](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length16_ = ntohs(remaining_length16_); + remaining_length_ = remaining_length16_; +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len64: + { + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32)); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 8) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Mask: + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 4) + { + throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Payload; + do_read(); + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + break; + case WebSocketReadState::Payload: + { + size_t to_read = buffer_.size(); + if (remaining_length_ < to_read) + to_read = remaining_length_; + adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + + if (!ec) + { + fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred); + remaining_length_ -= bytes_transferred; + if (remaining_length_ == 0) + { + handle_fragment(); + state_ = WebSocketReadState::MiniHeader; + do_read(); + } + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + } + break; + } + } + + bool is_FIN() + { + return mini_header_ & 0x8000; + } + + int opcode() + { + return (mini_header_ & 0x0f00) >> 8; + } + + void handle_fragment() + { + for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++) + { + fragment_[i] ^= ((char*)&mask_)[i%4]; + } + switch(opcode()) + { + case 0: // Continuation + { + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + case 1: // Text + { + is_binary_ = false; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 2: // Binary + { + is_binary_ = true; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 0x8: // Close + { + has_recv_close_ = true; + if (!has_sent_close_) + { + close(fragment_); + } + else + { + adaptor_.close(); + close_connection_ = true; + if (!is_close_handler_called_) + { + if (close_handler_) + close_handler_(*this, fragment_); + is_close_handler_called_ = true; + } + check_destroy(); + } + } + break; + case 0x9: // Ping + { + send_pong(fragment_); + } + break; + case 0xA: // Pong + { + pong_received_ = true; + } + break; + } + + fragment_.clear(); + } + + void do_write() + { + if (sending_buffers_.empty()) + { + sending_buffers_.swap(write_buffers_); + std::vector<boost::asio::const_buffer> buffers; + buffers.reserve(sending_buffers_.size()); + for(auto& s:sending_buffers_) + { + buffers.emplace_back(boost::asio::buffer(s)); + } + boost::asio::async_write(adaptor_.socket(), buffers, + [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/) + { + sending_buffers_.clear(); + if (!ec && !close_connection_) + { + if (!write_buffers_.empty()) + do_write(); + if (has_sent_close_) + close_connection_ = true; + } + else + { + close_connection_ = true; + check_destroy(); + } + }); + } + } + + void check_destroy() + { + //if (has_sent_close_ && has_recv_close_) + if (!is_close_handler_called_) + if (close_handler_) + close_handler_(*this, "uncleanly"); + if (sending_buffers_.empty() && !is_reading) + delete this; + } + private: + Adaptor adaptor_; + + std::vector<std::string> sending_buffers_; + std::vector<std::string> write_buffers_; + + boost::array<char, 4096> buffer_; + bool is_binary_; + std::string message_; + std::string fragment_; + WebSocketReadState state_{WebSocketReadState::MiniHeader}; + uint64_t remaining_length_{0}; + bool close_connection_{false}; + bool is_reading{false}; + uint32_t mask_; + uint16_t mini_header_; + bool has_sent_close_{false}; + bool has_recv_close_{false}; + bool error_occured_{false}; + bool pong_received_{false}; + bool is_close_handler_called_{false}; + + std::function<void(crow::websocket::connection&)> open_handler_; + std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_; + std::function<void(crow::websocket::connection&, const std::string&)> close_handler_; + std::function<void(crow::websocket::connection&)> error_handler_; + std::function<bool(const crow::request&)> accept_handler_; + }; + } +} |