From 967adf0de55afcb52881cdb1a7b16788c7c283db Mon Sep 17 00:00:00 2001 From: ipknHama Date: Sun, 28 Aug 2016 14:46:31 +0900 Subject: Add websocket feature --- include/TinySHA1.hpp | 196 +++++++++++++++++++ include/crow.h | 6 + include/http_connection.h | 16 ++ include/http_request.h | 16 ++ include/http_server.h | 8 +- include/mustache.h | 6 +- include/parser.h | 5 + include/routing.h | 167 +++++++++++++++- include/settings.h | 2 +- include/socket_adaptors.h | 24 ++- include/utility.h | 42 ++++ include/websocket.h | 482 ++++++++++++++++++++++++++++++++++++++++++++++ 12 files changed, 961 insertions(+), 9 deletions(-) create mode 100644 include/TinySHA1.hpp create mode 100644 include/websocket.h (limited to 'include') diff --git a/include/TinySHA1.hpp b/include/TinySHA1.hpp new file mode 100644 index 0000000..70af046 --- /dev/null +++ b/include/TinySHA1.hpp @@ -0,0 +1,196 @@ +/* + * + * TinySHA1 - a header only implementation of the SHA1 algorithm in C++. Based + * on the implementation in boost::uuid::details. + * + * SHA1 Wikipedia Page: http://en.wikipedia.org/wiki/SHA-1 + * + * Copyright (c) 2012-22 SAURAV MOHAPATRA + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ +#ifndef _TINY_SHA1_HPP_ +#define _TINY_SHA1_HPP_ +#include +#include +#include +#include +namespace sha1 +{ + class SHA1 + { + public: + typedef uint32_t digest32_t[5]; + typedef uint8_t digest8_t[20]; + inline static uint32_t LeftRotate(uint32_t value, size_t count) { + return (value << count) ^ (value >> (32-count)); + } + SHA1(){ reset(); } + virtual ~SHA1() {} + SHA1(const SHA1& s) { *this = s; } + const SHA1& operator = (const SHA1& s) { + memcpy(m_digest, s.m_digest, 5 * sizeof(uint32_t)); + memcpy(m_block, s.m_block, 64); + m_blockByteIndex = s.m_blockByteIndex; + m_byteCount = s.m_byteCount; + return *this; + } + SHA1& reset() { + m_digest[0] = 0x67452301; + m_digest[1] = 0xEFCDAB89; + m_digest[2] = 0x98BADCFE; + m_digest[3] = 0x10325476; + m_digest[4] = 0xC3D2E1F0; + m_blockByteIndex = 0; + m_byteCount = 0; + return *this; + } + SHA1& processByte(uint8_t octet) { + this->m_block[this->m_blockByteIndex++] = octet; + ++this->m_byteCount; + if(m_blockByteIndex == 64) { + this->m_blockByteIndex = 0; + processBlock(); + } + return *this; + } + SHA1& processBlock(const void* const start, const void* const end) { + const uint8_t* begin = static_cast(start); + const uint8_t* finish = static_cast(end); + while(begin != finish) { + processByte(*begin); + begin++; + } + return *this; + } + SHA1& processBytes(const void* const data, size_t len) { + const uint8_t* block = static_cast(data); + processBlock(block, block + len); + return *this; + } + const uint32_t* getDigest(digest32_t digest) { + size_t bitCount = this->m_byteCount * 8; + processByte(0x80); + if (this->m_blockByteIndex > 56) { + while (m_blockByteIndex != 0) { + processByte(0); + } + while (m_blockByteIndex < 56) { + processByte(0); + } + } else { + while (m_blockByteIndex < 56) { + processByte(0); + } + } + processByte(0); + processByte(0); + processByte(0); + processByte(0); + processByte( static_cast((bitCount>>24) & 0xFF)); + processByte( static_cast((bitCount>>16) & 0xFF)); + processByte( static_cast((bitCount>>8 ) & 0xFF)); + processByte( static_cast((bitCount) & 0xFF)); + + memcpy(digest, m_digest, 5 * sizeof(uint32_t)); + return digest; + } + const uint8_t* getDigestBytes(digest8_t digest) { + digest32_t d32; + getDigest(d32); + size_t di = 0; + digest[di++] = ((d32[0] >> 24) & 0xFF); + digest[di++] = ((d32[0] >> 16) & 0xFF); + digest[di++] = ((d32[0] >> 8) & 0xFF); + digest[di++] = ((d32[0]) & 0xFF); + + digest[di++] = ((d32[1] >> 24) & 0xFF); + digest[di++] = ((d32[1] >> 16) & 0xFF); + digest[di++] = ((d32[1] >> 8) & 0xFF); + digest[di++] = ((d32[1]) & 0xFF); + + digest[di++] = ((d32[2] >> 24) & 0xFF); + digest[di++] = ((d32[2] >> 16) & 0xFF); + digest[di++] = ((d32[2] >> 8) & 0xFF); + digest[di++] = ((d32[2]) & 0xFF); + + digest[di++] = ((d32[3] >> 24) & 0xFF); + digest[di++] = ((d32[3] >> 16) & 0xFF); + digest[di++] = ((d32[3] >> 8) & 0xFF); + digest[di++] = ((d32[3]) & 0xFF); + + digest[di++] = ((d32[4] >> 24) & 0xFF); + digest[di++] = ((d32[4] >> 16) & 0xFF); + digest[di++] = ((d32[4] >> 8) & 0xFF); + digest[di++] = ((d32[4]) & 0xFF); + return digest; + } + + protected: + void processBlock() { + uint32_t w[80]; + for (size_t i = 0; i < 16; i++) { + w[i] = (m_block[i*4 + 0] << 24); + w[i] |= (m_block[i*4 + 1] << 16); + w[i] |= (m_block[i*4 + 2] << 8); + w[i] |= (m_block[i*4 + 3]); + } + for (size_t i = 16; i < 80; i++) { + w[i] = LeftRotate((w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]), 1); + } + + uint32_t a = m_digest[0]; + uint32_t b = m_digest[1]; + uint32_t c = m_digest[2]; + uint32_t d = m_digest[3]; + uint32_t e = m_digest[4]; + + for (std::size_t i=0; i<80; ++i) { + uint32_t f = 0; + uint32_t k = 0; + + if (i<20) { + f = (b & c) | (~b & d); + k = 0x5A827999; + } else if (i<40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i<60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + uint32_t temp = LeftRotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = LeftRotate(b, 30); + b = a; + a = temp; + } + + m_digest[0] += a; + m_digest[1] += b; + m_digest[2] += c; + m_digest[3] += d; + m_digest[4] += e; + } + private: + digest32_t m_digest; + uint8_t m_block[64]; + size_t m_blockByteIndex; + size_t m_byteCount; + }; +} +#endif diff --git a/include/crow.h b/include/crow.h index 5d99b91..00209c7 100644 --- a/include/crow.h +++ b/include/crow.h @@ -41,6 +41,12 @@ namespace crow { } + template + void handle_upgrade(const request& req, response& res, Adaptor&& adaptor) + { + router_.handle_upgrade(req, res, adaptor); + } + void handle(const request& req, response& res) { router_.handle(req, res); diff --git a/include/http_connection.h b/include/http_connection.h index 2bc6906..5517521 100644 --- a/include/http_connection.h +++ b/include/http_connection.h @@ -256,6 +256,7 @@ namespace crow req_ = std::move(parser_.to_request()); request& req = req_; + if (parser_.check_version(1, 0)) { // HTTP/1.0 @@ -282,6 +283,20 @@ namespace crow is_invalid_request = true; res = response(400); } + if (parser_.is_upgrade()) + { + if (req.get_header_value("upgrade") == "h2c") + { + // TODO HTTP/2 + // currently, ignore upgrade header + } + else + { + close_connection_ = true; + handler_->handle_upgrade(req, res, std::move(adaptor_)); + return; + } + } } CROW_LOG_INFO << "Request: " << boost::lexical_cast(adaptor_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' ' @@ -296,6 +311,7 @@ namespace crow ctx_ = detail::context(); req.middleware_context = (void*)&ctx_; + req.io_service = &adaptor_.get_io_service(); detail::middleware_call_helper<0, decltype(ctx_), decltype(*middlewares_), Middlewares...>(*middlewares_, req, res, ctx_); if (!res.completed_) diff --git a/include/http_request.h b/include/http_request.h index ba1ff75..535a1fd 100644 --- a/include/http_request.h +++ b/include/http_request.h @@ -3,6 +3,7 @@ #include "common.h" #include "ci_map.h" #include "query_string.h" +#include namespace crow { @@ -17,6 +18,8 @@ namespace crow return empty; } + struct DetachHelper; + struct request { HTTPMethod method; @@ -27,6 +30,7 @@ namespace crow std::string body; void* middleware_context{}; + boost::asio::io_service* io_service{}; request() : method(HTTPMethod::Get) @@ -48,5 +52,17 @@ namespace crow return crow::get_header_value(headers, key); } + template + void post(CompletionHandler handler) + { + io_service->post(handler); + } + + template + void dispatch(CompletionHandler handler) + { + io_service->dispatch(handler); + } + }; } diff --git a/include/http_server.h b/include/http_server.h index 94f2fc3..addbbc1 100644 --- a/include/http_server.h +++ b/include/http_server.h @@ -99,7 +99,13 @@ namespace crow }; timer.async_wait(handler); - io_service_pool_[i]->run(); + try + { + io_service_pool_[i]->run(); + } catch(std::exception& e) + { + CROW_LOG_ERROR << "Worker Crash: An uncaught exception occurred: " << e.what(); + } })); CROW_LOG_INFO << server_name_ << " server is running, local port " << port_; diff --git a/include/mustache.h b/include/mustache.h index b596b45..279f356 100644 --- a/include/mustache.h +++ b/include/mustache.h @@ -520,7 +520,11 @@ namespace crow inline std::string default_loader(const std::string& filename) { - std::ifstream inf(detail::get_template_base_directory_ref() + filename); + std::string path = detail::get_template_base_directory_ref(); + if (!(path.back() == '/' || path.back() == '\\')) + path += '/'; + path += filename; + std::ifstream inf(path); if (!inf) return {}; return {std::istreambuf_iterator(inf), std::istreambuf_iterator()}; diff --git a/include/parser.h b/include/parser.h index f6b748b..b621850 100644 --- a/include/parser.h +++ b/include/parser.h @@ -143,6 +143,11 @@ namespace crow return request{(HTTPMethod)method, std::move(raw_url), std::move(url), std::move(url_params), std::move(headers), std::move(body)}; } + bool is_upgrade() const + { + return upgrade; + } + bool check_version(int major, int minor) const { return http_major == major && http_minor == minor; diff --git a/include/routing.h b/include/routing.h index 418209c..4fc2de8 100644 --- a/include/routing.h +++ b/include/routing.h @@ -13,6 +13,7 @@ #include "http_request.h" #include "utility.h" #include "logging.h" +#include "websocket.h" namespace crow { @@ -29,8 +30,26 @@ namespace crow } virtual void validate() = 0; + std::unique_ptr upgrade() + { + if (rule_to_upgrade_) + return std::move(rule_to_upgrade_); + return {}; + } virtual void handle(const request&, response&, const routing_params&) = 0; + virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&) + { + res = response(404); + res.end(); + } +#ifdef CROW_ENABLE_SSL + virtual void handle_upgrade(const request&, response& res, SSLAdaptor&&) + { + res = response(404); + res.end(); + } +#endif uint32_t get_methods() { @@ -42,6 +61,9 @@ namespace crow std::string rule_; std::string name_; + + std::unique_ptr rule_to_upgrade_; + friend class Router; template friend struct RuleParameterTraits; @@ -233,10 +255,82 @@ namespace crow } } + class WebSocketRule : public BaseRule + { + using self_t = WebSocketRule; + public: + WebSocketRule(std::string rule) + : BaseRule(std::move(rule)) + { + } + + void validate() override + { + } + + void handle(const request&, response& res, const routing_params&) override + { + res = response(404); + res.end(); + } + + void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override + { + new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_); + } +#ifdef CROW_ENABLE_SSL + void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override + { + new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_); + } +#endif + + template + self_t& onopen(Func f) + { + open_handler_ = f; + return *this; + } + + template + self_t& onmessage(Func f) + { + message_handler_ = f; + return *this; + } + + template + self_t& onclose(Func f) + { + close_handler_ = f; + return *this; + } + + template + self_t& onerror(Func f) + { + error_handler_ = f; + return *this; + } + + protected: + std::function open_handler_; + std::function message_handler_; + std::function close_handler_; + std::function error_handler_; + }; + template struct RuleParameterTraits { using self_t = T; + WebSocketRule& websocket() + { + auto p =new WebSocketRule(((self_t*)this)->rule_); + ((self_t*)this)->rule_to_upgrade_.reset(p); + return *p; + } + self_t& name(std::string name) noexcept { ((self_t*)this)->name_ = std::move(name); @@ -256,6 +350,7 @@ namespace crow ((self_t*)this)->methods_ |= 1 << (int)method; return (self_t&)*this; } + }; class DynamicRule : public BaseRule, public RuleParameterTraits @@ -343,7 +438,7 @@ namespace crow { } - void validate() + void validate() override { if (!handler_) { @@ -809,10 +904,80 @@ public: for(auto& rule:rules_) { if (rule) + { + auto upgraded = rule->upgrade(); + if (upgraded) + rule = std::move(upgraded); rule->validate(); + } } } + template + void handle_upgrade(const request& req, response& res, Adaptor&& adaptor) + { + auto found = trie_.find(req.url); + unsigned rule_index = found.first; + if (!rule_index) + { + CROW_LOG_DEBUG << "Cannot match rules " << req.url; + res = response(404); + res.end(); + return; + } + + if (rule_index >= rules_.size()) + throw std::runtime_error("Trie internal structure corrupted!"); + + if (rule_index == RULE_SPECIAL_REDIRECT_SLASH) + { + CROW_LOG_INFO << "Redirecting to a url with trailing slash: " << req.url; + res = response(301); + + // TODO absolute url building + if (req.get_header_value("Host").empty()) + { + res.add_header("Location", req.url + "/"); + } + else + { + res.add_header("Location", "http://" + req.get_header_value("Host") + req.url + "/"); + } + res.end(); + return; + } + + if ((rules_[rule_index]->get_methods() & (1<<(uint32_t)req.method)) == 0) + { + CROW_LOG_DEBUG << "Rule found but method mismatch: " << req.url << " with " << method_name(req.method) << "(" << (uint32_t)req.method << ") / " << rules_[rule_index]->get_methods(); + res = response(404); + res.end(); + return; + } + + CROW_LOG_DEBUG << "Matched rule (upgrade) '" << rules_[rule_index]->rule_ << "' " << (uint32_t)req.method << " / " << rules_[rule_index]->get_methods(); + + // any uncaught exceptions become 500s + try + { + rules_[rule_index]->handle_upgrade(req, res, std::move(adaptor)); + } + catch(std::exception& e) + { + CROW_LOG_ERROR << "An uncaught exception occurred: " << e.what(); + res = response(500); + res.end(); + return; + } + catch(...) + { + CROW_LOG_ERROR << "An uncaught exception occurred. The type was unknown so no information was available."; + res = response(500); + res.end(); + return; + } + } + void handle(const request& req, response& res) { auto found = trie_.find(req.url); diff --git a/include/settings.h b/include/settings.h index d8dfc9c..5c67f3b 100644 --- a/include/settings.h +++ b/include/settings.h @@ -8,7 +8,7 @@ /* #ifdef - enables logging */ #define CROW_ENABLE_LOGGING -/* #ifdef - enables SSL */ +/* #ifdef - enables ssl */ //#define CROW_ENABLE_SSL /* #define - specifies log level */ diff --git a/include/socket_adaptors.h b/include/socket_adaptors.h index 201360c..634bd4b 100644 --- a/include/socket_adaptors.h +++ b/include/socket_adaptors.h @@ -1,5 +1,8 @@ #pragma once #include +#ifdef CROW_ENABLE_SSL +#include +#endif #include "settings.h" namespace crow { @@ -14,6 +17,11 @@ namespace crow { } + boost::asio::io_service& get_io_service() + { + return socket_.get_io_service(); + } + tcp::socket& raw_socket() { return socket_; @@ -52,20 +60,21 @@ namespace crow struct SSLAdaptor { using context = boost::asio::ssl::context; + using ssl_socket_t = boost::asio::ssl::stream; SSLAdaptor(boost::asio::io_service& io_service, context* ctx) - : ssl_socket_(io_service, *ctx) + : ssl_socket_(new ssl_socket_t(io_service, *ctx)) { } boost::asio::ssl::stream& socket() { - return ssl_socket_; + return *ssl_socket_; } tcp::socket::lowest_layer_type& raw_socket() { - return ssl_socket_.lowest_layer(); + return ssl_socket_->lowest_layer(); } tcp::endpoint remote_endpoint() @@ -83,16 +92,21 @@ namespace crow raw_socket().close(); } + boost::asio::io_service& get_io_service() + { + return raw_socket().get_io_service(); + } + template void start(F f) { - ssl_socket_.async_handshake(boost::asio::ssl::stream_base::server, + ssl_socket_->async_handshake(boost::asio::ssl::stream_base::server, [f](const boost::system::error_code& ec) { f(ec); }); } - boost::asio::ssl::stream ssl_socket_; + std::unique_ptr> ssl_socket_; }; #endif } diff --git a/include/utility.h b/include/utility.h index 183d65b..fe9029e 100644 --- a/include/utility.h +++ b/include/utility.h @@ -499,5 +499,47 @@ template using arg = typename std::tuple_element>::type; }; + std::string base64encode(const char* data, size_t size, const char* key = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + { + std::string ret; + ret.resize((size+2) / 3 * 4); + auto it = ret.begin(); + while(size >= 3) + { + *it++ = key[(((unsigned char)*data)&0xFC)>>2]; + unsigned char h = (((unsigned char)*data++) & 0x03) << 4; + *it++ = key[h|((((unsigned char)*data)&0xF0)>>4)]; + h = (((unsigned char)*data++) & 0x0F) << 2; + *it++ = key[h|((((unsigned char)*data)&0xC0)>>6)]; + *it++ = key[((unsigned char)*data++)&0x3F]; + + size -= 3; + } + if (size == 1) + { + *it++ = key[(((unsigned char)*data)&0xFC)>>2]; + unsigned char h = (((unsigned char)*data++) & 0x03) << 4; + *it++ = key[h]; + *it++ = '='; + *it++ = '='; + } + else if (size == 2) + { + *it++ = key[(((unsigned char)*data)&0xFC)>>2]; + unsigned char h = (((unsigned char)*data++) & 0x03) << 4; + *it++ = key[h|((((unsigned char)*data)&0xF0)>>4)]; + h = (((unsigned char)*data++) & 0x0F) << 2; + *it++ = key[h]; + *it++ = '='; + } + return ret; + } + + std::string base64encode_urlsafe(const char* data, size_t size) + { + return base64encode(data, size, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"); + } + + } // namespace utility } diff --git a/include/websocket.h b/include/websocket.h new file mode 100644 index 0000000..5299c1a --- /dev/null +++ b/include/websocket.h @@ -0,0 +1,482 @@ +#pragma once +#include "socket_adaptors.h" +#include "http_request.h" +#include "TinySHA1.hpp" + +namespace crow +{ + namespace websocket + { + enum class WebSocketReadState + { + MiniHeader, + Len16, + Len64, + Mask, + Payload, + }; + + struct connection + { + virtual void send_binary(const std::string& msg) = 0; + virtual void send_text(const std::string& msg) = 0; + virtual void close(const std::string& msg = "quit") = 0; + virtual ~connection(){} + }; + + template + class Connection : public connection + { + public: + Connection(const crow::request& req, Adaptor&& adaptor, + std::function open_handler, + std::function message_handler, + std::function close_handler, + std::function error_handler) + : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)) + { + if (req.get_header_value("upgrade") != "websocket") + { + adaptor.close(); + delete this; + return; + } + // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + // Sec-WebSocket-Version: 13 + std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + sha1::SHA1 s; + s.processBytes(magic.data(), magic.size()); + uint8_t digest[20]; + s.getDigestBytes(digest); + start(crow::utility::base64encode((char*)digest, 20)); + } + + template + void dispatch(CompletionHandler handler) + { + adaptor_.get_io_service().dispatch(handler); + } + + template + void post(CompletionHandler handler) + { + adaptor_.get_io_service().post(handler); + } + + void send_pong(const std::string& msg) + { + dispatch([this, msg]{ + char buf[3] = "\x8A\x00"; + buf[1] += msg.size(); + write_buffers_.emplace_back(buf, buf+2); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_binary(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(2, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_text(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(1, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void close(const std::string& msg) override + { + dispatch([this, msg]{ + has_sent_close_ = true; + if (has_recv_close_ && !is_close_handler_called_) + { + is_close_handler_called_ = true; + if (close_handler_) + close_handler_(*this, msg); + } + auto header = build_header(0x8, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + protected: + + std::string build_header(int opcode, size_t size) + { + char buf[2+8] = "\x80\x00"; + buf[0] += opcode; + if (size < 126) + { + buf[1] += size; + return {buf, buf+2}; + } + else if (size < 0x10000) + { + buf[1] += 126; + *(uint16_t*)(buf+2) = (uint16_t)size; + return {buf, buf+4}; + } + else + { + buf[1] += 127; + *(uint64_t*)(buf+2) = (uint64_t)size; + return {buf, buf+10}; + } + } + + void start(std::string&& hello) + { + static std::string header = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "; + static std::string crlf = "\r\n"; + write_buffers_.emplace_back(header); + write_buffers_.emplace_back(std::move(hello)); + write_buffers_.emplace_back(crlf); + write_buffers_.emplace_back(crlf); + do_write(); + if (open_handler_) + open_handler_(*this); + do_read(); + } + + void do_read() + { + is_reading = true; + switch(state_) + { + case WebSocketReadState::MiniHeader: + { + //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1), + adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + mini_header_ = htons(mini_header_); +#ifdef CROW_ENABLE_DEBUG + + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?"); + } +#endif + + if (!ec && ((mini_header_ & 0x80) == 0x80)) + { + if ((mini_header_ & 0x7f) == 127) + { + state_ = WebSocketReadState::Len64; + } + else if ((mini_header_ & 0x7f) == 126) + { + state_ = WebSocketReadState::Len16; + } + else + { + remaining_length_ = mini_header_ & 0x7f; + state_ = WebSocketReadState::Mask; + } + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len16: + { + remaining_length_ = 0; + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ntohs(*(uint16_t*)&remaining_length_); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len64: + { + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32)); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 8) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Mask: + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 4) + { + throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Payload; + do_read(); + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + break; + case WebSocketReadState::Payload: + { + size_t to_read = buffer_.size(); + if (remaining_length_ < to_read) + to_read = remaining_length_; + adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + + if (!ec) + { + fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred); + remaining_length_ -= bytes_transferred; + if (remaining_length_ == 0) + { + handle_fragment(); + state_ = WebSocketReadState::MiniHeader; + do_read(); + } + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + } + break; + } + } + + bool is_FIN() + { + return mini_header_ & 0x8000; + } + + int opcode() + { + return (mini_header_ & 0x0f00) >> 8; + } + + void handle_fragment() + { + for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++) + { + fragment_[i] ^= ((char*)&mask_)[i%4]; + } + switch(opcode()) + { + case 0: // Continuation + { + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + case 1: // Text + { + is_binary_ = false; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 2: // Binary + { + is_binary_ = true; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 0x8: // Close + { + has_recv_close_ = true; + if (!has_sent_close_) + { + close(fragment_); + } + else + { + adaptor_.close(); + close_connection_ = true; + if (!is_close_handler_called_) + { + if (close_handler_) + close_handler_(*this, fragment_); + is_close_handler_called_ = true; + } + check_destroy(); + } + } + break; + case 0x9: // Ping + { + send_pong(fragment_); + } + break; + case 0xA: // Pong + { + pong_received_ = true; + } + break; + } + + fragment_.clear(); + } + + void do_write() + { + if (sending_buffers_.empty()) + { + sending_buffers_.swap(write_buffers_); + std::vector buffers; + buffers.reserve(sending_buffers_.size()); + for(auto& s:sending_buffers_) + { + buffers.emplace_back(boost::asio::buffer(s)); + } + boost::asio::async_write(adaptor_.socket(), buffers, + [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/) + { + sending_buffers_.clear(); + if (!ec && !close_connection_) + { + if (!write_buffers_.empty()) + do_write(); + if (has_sent_close_) + close_connection_ = true; + } + else + { + close_connection_ = true; + check_destroy(); + } + }); + } + } + + void check_destroy() + { + //if (has_sent_close_ && has_recv_close_) + if (!is_close_handler_called_) + if (close_handler_) + close_handler_(*this, "uncleanly"); + if (sending_buffers_.empty() && !is_reading) + delete this; + } + private: + Adaptor adaptor_; + + std::vector sending_buffers_; + std::vector write_buffers_; + + boost::array buffer_; + bool is_binary_; + std::string message_; + std::string fragment_; + WebSocketReadState state_{WebSocketReadState::MiniHeader}; + uint64_t remaining_length_{0}; + bool close_connection_{false}; + bool is_reading{false}; + uint32_t mask_; + uint16_t mini_header_; + bool has_sent_close_{false}; + bool has_recv_close_{false}; + bool error_occured_{false}; + bool pong_received_{false}; + bool is_close_handler_called_{false}; + + std::function open_handler_; + std::function message_handler_; + std::function close_handler_; + std::function error_handler_; + }; + } +} -- cgit v1.2.3-54-g00ecf