aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoripknHama <ipknhama@gmail.com>2016-08-28 14:46:31 +0900
committeripknHama <ipknhama@gmail.com>2016-08-28 14:46:31 +0900
commit967adf0de55afcb52881cdb1a7b16788c7c283db (patch)
treedbe4fe620a136bdb462a4ad29e83d6d699b3b447
parent45f6d12fd382662675000fb1c60909287733127c (diff)
downloadcrow-967adf0de55afcb52881cdb1a7b16788c7c283db.tar.gz
crow-967adf0de55afcb52881cdb1a7b16788c7c283db.zip
Add websocket feature
-rw-r--r--README.md39
-rw-r--r--examples/CMakeLists.txt4
-rw-r--r--examples/websocket/example_ws.cpp45
-rw-r--r--examples/websocket/templates/ws.html41
-rw-r--r--include/TinySHA1.hpp196
-rw-r--r--include/crow.h6
-rw-r--r--include/http_connection.h16
-rw-r--r--include/http_request.h16
-rw-r--r--include/http_server.h8
-rw-r--r--include/mustache.h6
-rw-r--r--include/parser.h5
-rw-r--r--include/routing.h167
-rw-r--r--include/settings.h2
-rw-r--r--include/socket_adaptors.h24
-rw-r--r--include/utility.h42
-rw-r--r--include/websocket.h482
16 files changed, 1090 insertions, 9 deletions
diff --git a/README.md b/README.md
index 60a9af5..a5daf1d 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,35 @@ ctest
Crow uses the following libraries.
+ http-parser
+
+ https://github.com/nodejs/http-parser
+
+ http_parser.c is based on src/http/ngx_http_parse.c from NGINX copyright
+ Igor Sysoev.
+
+ Additional changes are licensed under the same terms as NGINX and
+ copyright Joyent, Inc. and other Node contributors. All rights reserved.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to
+ deal in the Software without restriction, including without limitation the
+ rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ sell copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ IN THE SOFTWARE.
+
+
qs_parse
https://github.com/bartgrantham/qs_parse
@@ -141,3 +170,13 @@ Crow uses the following libraries.
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
+
+ TinySHA1
+
+ https://github.com/mohaps/TinySHA1
+
+ TinySHA1 - a header only implementation of the SHA1 algorithm. Based on the implementation in boost::uuid::details
+
+ Copyright (c) 2012-22 SAURAV MOHAPATRA mohaps@gmail.com
+ Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 22139ad..1e96dea 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -15,6 +15,10 @@ add_executable(example_ssl ssl/example_ssl.cpp)
target_link_libraries(example_ssl ${Boost_LIBRARIES})
target_link_libraries(example_ssl ${CMAKE_THREAD_LIBS_INIT} ssl crypto)
+add_executable(example_websocket websocket/example_ws.cpp)
+target_link_libraries(example_websocket ${Boost_LIBRARIES})
+target_link_libraries(example_websocket ${CMAKE_THREAD_LIBS_INIT} ssl crypto)
+
add_executable(example example.cpp)
#target_link_libraries(example crow)
target_link_libraries(example ${Boost_LIBRARIES})
diff --git a/examples/websocket/example_ws.cpp b/examples/websocket/example_ws.cpp
new file mode 100644
index 0000000..7fbd5ef
--- /dev/null
+++ b/examples/websocket/example_ws.cpp
@@ -0,0 +1,45 @@
+#include "crow.h"
+#include "mustache.h"
+#include "websocket.h"
+#include <unordered_set>
+#include <mutex>
+
+
+int main()
+{
+ crow::SimpleApp app;
+
+ std::mutex mtx;;
+ std::unordered_set<crow::websocket::connection*> users;
+
+ CROW_ROUTE(app, "/ws")
+ .websocket()
+ .onopen([&](crow::websocket::connection& conn){
+ CROW_LOG_INFO << "new websocket connection";
+ std::lock_guard<std::mutex> _(mtx);
+ users.insert(&conn);
+ })
+ .onclose([&](crow::websocket::connection& conn, const std::string& reason){
+ CROW_LOG_INFO << "websocket connection closed: " << reason;
+ std::lock_guard<std::mutex> _(mtx);
+ users.erase(&conn);
+ })
+ .onmessage([&](crow::websocket::connection& /*conn*/, const std::string& data, bool is_binary){
+ std::lock_guard<std::mutex> _(mtx);
+ for(auto u:users)
+ if (is_binary)
+ u->send_binary(data);
+ else
+ u->send_text(data);
+ });
+
+ CROW_ROUTE(app, "/")
+ ([]{
+ auto page = crow::mustache::load("ws.html");
+ return page.render();
+ });
+
+ app.port(40080)
+ .multithreaded()
+ .run();
+}
diff --git a/examples/websocket/templates/ws.html b/examples/websocket/templates/ws.html
new file mode 100644
index 0000000..f6e7281
--- /dev/null
+++ b/examples/websocket/templates/ws.html
@@ -0,0 +1,41 @@
+<!doctype html>
+<html>
+<head>
+ <script src="https://code.jquery.com/jquery-3.1.0.min.js"></script>
+</head>
+<body>
+ <input id="msg" type="text"></input>
+ <button id="send">
+ Send
+ </button><BR>
+ <textarea id="log" cols=100 rows=50>
+ </textarea>
+ <script>
+var sock = new WebSocket("ws://i.ipkn.me:40080/ws");
+sock.onopen = ()=>{
+ console.log('open')
+}
+sock.onerror = (e)=>{
+ console.log('error',e)
+}
+sock.onclose = ()=>{
+ console.log('close')
+}
+sock.onmessage = (e)=>{
+ $("#log").val(
+ e.data +"\n" + $("#log").val());
+}
+$("#msg").keypress(function(e){
+ if (e.which == 13)
+ {
+ sock.send($("#msg").val());
+ $("#msg").val("");
+ }
+});
+$("#send").click(()=>{
+ sock.send($("#msg").val());
+ $("#msg").val("");
+});
+ </script>
+</body>
+</html>
diff --git a/include/TinySHA1.hpp b/include/TinySHA1.hpp
new file mode 100644
index 0000000..70af046
--- /dev/null
+++ b/include/TinySHA1.hpp
@@ -0,0 +1,196 @@
+/*
+ *
+ * TinySHA1 - a header only implementation of the SHA1 algorithm in C++. Based
+ * on the implementation in boost::uuid::details.
+ *
+ * SHA1 Wikipedia Page: http://en.wikipedia.org/wiki/SHA-1
+ *
+ * Copyright (c) 2012-22 SAURAV MOHAPATRA <mohaps@gmail.com>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+#ifndef _TINY_SHA1_HPP_
+#define _TINY_SHA1_HPP_
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <stdint.h>
+namespace sha1
+{
+ class SHA1
+ {
+ public:
+ typedef uint32_t digest32_t[5];
+ typedef uint8_t digest8_t[20];
+ inline static uint32_t LeftRotate(uint32_t value, size_t count) {
+ return (value << count) ^ (value >> (32-count));
+ }
+ SHA1(){ reset(); }
+ virtual ~SHA1() {}
+ SHA1(const SHA1& s) { *this = s; }
+ const SHA1& operator = (const SHA1& s) {
+ memcpy(m_digest, s.m_digest, 5 * sizeof(uint32_t));
+ memcpy(m_block, s.m_block, 64);
+ m_blockByteIndex = s.m_blockByteIndex;
+ m_byteCount = s.m_byteCount;
+ return *this;
+ }
+ SHA1& reset() {
+ m_digest[0] = 0x67452301;
+ m_digest[1] = 0xEFCDAB89;
+ m_digest[2] = 0x98BADCFE;
+ m_digest[3] = 0x10325476;
+ m_digest[4] = 0xC3D2E1F0;
+ m_blockByteIndex = 0;
+ m_byteCount = 0;
+ return *this;
+ }
+ SHA1& processByte(uint8_t octet) {
+ this->m_block[this->m_blockByteIndex++] = octet;
+ ++this->m_byteCount;
+ if(m_blockByteIndex == 64) {
+ this->m_blockByteIndex = 0;
+ processBlock();
+ }
+ return *this;
+ }
+ SHA1& processBlock(const void* const start, const void* const end) {
+ const uint8_t* begin = static_cast<const uint8_t*>(start);
+ const uint8_t* finish = static_cast<const uint8_t*>(end);
+ while(begin != finish) {
+ processByte(*begin);
+ begin++;
+ }
+ return *this;
+ }
+ SHA1& processBytes(const void* const data, size_t len) {
+ const uint8_t* block = static_cast<const uint8_t*>(data);
+ processBlock(block, block + len);
+ return *this;
+ }
+ const uint32_t* getDigest(digest32_t digest) {
+ size_t bitCount = this->m_byteCount * 8;
+ processByte(0x80);
+ if (this->m_blockByteIndex > 56) {
+ while (m_blockByteIndex != 0) {
+ processByte(0);
+ }
+ while (m_blockByteIndex < 56) {
+ processByte(0);
+ }
+ } else {
+ while (m_blockByteIndex < 56) {
+ processByte(0);
+ }
+ }
+ processByte(0);
+ processByte(0);
+ processByte(0);
+ processByte(0);
+ processByte( static_cast<unsigned char>((bitCount>>24) & 0xFF));
+ processByte( static_cast<unsigned char>((bitCount>>16) & 0xFF));
+ processByte( static_cast<unsigned char>((bitCount>>8 ) & 0xFF));
+ processByte( static_cast<unsigned char>((bitCount) & 0xFF));
+
+ memcpy(digest, m_digest, 5 * sizeof(uint32_t));
+ return digest;
+ }
+ const uint8_t* getDigestBytes(digest8_t digest) {
+ digest32_t d32;
+ getDigest(d32);
+ size_t di = 0;
+ digest[di++] = ((d32[0] >> 24) & 0xFF);
+ digest[di++] = ((d32[0] >> 16) & 0xFF);
+ digest[di++] = ((d32[0] >> 8) & 0xFF);
+ digest[di++] = ((d32[0]) & 0xFF);
+
+ digest[di++] = ((d32[1] >> 24) & 0xFF);
+ digest[di++] = ((d32[1] >> 16) & 0xFF);
+ digest[di++] = ((d32[1] >> 8) & 0xFF);
+ digest[di++] = ((d32[1]) & 0xFF);
+
+ digest[di++] = ((d32[2] >> 24) & 0xFF);
+ digest[di++] = ((d32[2] >> 16) & 0xFF);
+ digest[di++] = ((d32[2] >> 8) & 0xFF);
+ digest[di++] = ((d32[2]) & 0xFF);
+
+ digest[di++] = ((d32[3] >> 24) & 0xFF);
+ digest[di++] = ((d32[3] >> 16) & 0xFF);
+ digest[di++] = ((d32[3] >> 8) & 0xFF);
+ digest[di++] = ((d32[3]) & 0xFF);
+
+ digest[di++] = ((d32[4] >> 24) & 0xFF);
+ digest[di++] = ((d32[4] >> 16) & 0xFF);
+ digest[di++] = ((d32[4] >> 8) & 0xFF);
+ digest[di++] = ((d32[4]) & 0xFF);
+ return digest;
+ }
+
+ protected:
+ void processBlock() {
+ uint32_t w[80];
+ for (size_t i = 0; i < 16; i++) {
+ w[i] = (m_block[i*4 + 0] << 24);
+ w[i] |= (m_block[i*4 + 1] << 16);
+ w[i] |= (m_block[i*4 + 2] << 8);
+ w[i] |= (m_block[i*4 + 3]);
+ }
+ for (size_t i = 16; i < 80; i++) {
+ w[i] = LeftRotate((w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]), 1);
+ }
+
+ uint32_t a = m_digest[0];
+ uint32_t b = m_digest[1];
+ uint32_t c = m_digest[2];
+ uint32_t d = m_digest[3];
+ uint32_t e = m_digest[4];
+
+ for (std::size_t i=0; i<80; ++i) {
+ uint32_t f = 0;
+ uint32_t k = 0;
+
+ if (i<20) {
+ f = (b & c) | (~b & d);
+ k = 0x5A827999;
+ } else if (i<40) {
+ f = b ^ c ^ d;
+ k = 0x6ED9EBA1;
+ } else if (i<60) {
+ f = (b & c) | (b & d) | (c & d);
+ k = 0x8F1BBCDC;
+ } else {
+ f = b ^ c ^ d;
+ k = 0xCA62C1D6;
+ }
+ uint32_t temp = LeftRotate(a, 5) + f + e + k + w[i];
+ e = d;
+ d = c;
+ c = LeftRotate(b, 30);
+ b = a;
+ a = temp;
+ }
+
+ m_digest[0] += a;
+ m_digest[1] += b;
+ m_digest[2] += c;
+ m_digest[3] += d;
+ m_digest[4] += e;
+ }
+ private:
+ digest32_t m_digest;
+ uint8_t m_block[64];
+ size_t m_blockByteIndex;
+ size_t m_byteCount;
+ };
+}
+#endif
diff --git a/include/crow.h b/include/crow.h
index 5d99b91..00209c7 100644
--- a/include/crow.h
+++ b/include/crow.h
@@ -41,6 +41,12 @@ namespace crow
{
}
+ template <typename Adaptor>
+ void handle_upgrade(const request& req, response& res, Adaptor&& adaptor)
+ {
+ router_.handle_upgrade(req, res, adaptor);
+ }
+
void handle(const request& req, response& res)
{
router_.handle(req, res);
diff --git a/include/http_connection.h b/include/http_connection.h
index 2bc6906..5517521 100644
--- a/include/http_connection.h
+++ b/include/http_connection.h
@@ -256,6 +256,7 @@ namespace crow
req_ = std::move(parser_.to_request());
request& req = req_;
+
if (parser_.check_version(1, 0))
{
// HTTP/1.0
@@ -282,6 +283,20 @@ namespace crow
is_invalid_request = true;
res = response(400);
}
+ if (parser_.is_upgrade())
+ {
+ if (req.get_header_value("upgrade") == "h2c")
+ {
+ // TODO HTTP/2
+ // currently, ignore upgrade header
+ }
+ else
+ {
+ close_connection_ = true;
+ handler_->handle_upgrade(req, res, std::move(adaptor_));
+ return;
+ }
+ }
}
CROW_LOG_INFO << "Request: " << boost::lexical_cast<std::string>(adaptor_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' '
@@ -296,6 +311,7 @@ namespace crow
ctx_ = detail::context<Middlewares...>();
req.middleware_context = (void*)&ctx_;
+ req.io_service = &adaptor_.get_io_service();
detail::middleware_call_helper<0, decltype(ctx_), decltype(*middlewares_), Middlewares...>(*middlewares_, req, res, ctx_);
if (!res.completed_)
diff --git a/include/http_request.h b/include/http_request.h
index ba1ff75..535a1fd 100644
--- a/include/http_request.h
+++ b/include/http_request.h
@@ -3,6 +3,7 @@
#include "common.h"
#include "ci_map.h"
#include "query_string.h"
+#include <boost/asio.hpp>
namespace crow
{
@@ -17,6 +18,8 @@ namespace crow
return empty;
}
+ struct DetachHelper;
+
struct request
{
HTTPMethod method;
@@ -27,6 +30,7 @@ namespace crow
std::string body;
void* middleware_context{};
+ boost::asio::io_service* io_service{};
request()
: method(HTTPMethod::Get)
@@ -48,5 +52,17 @@ namespace crow
return crow::get_header_value(headers, key);
}
+ template<typename CompletionHandler>
+ void post(CompletionHandler handler)
+ {
+ io_service->post(handler);
+ }
+
+ template<typename CompletionHandler>
+ void dispatch(CompletionHandler handler)
+ {
+ io_service->dispatch(handler);
+ }
+
};
}
diff --git a/include/http_server.h b/include/http_server.h
index 94f2fc3..addbbc1 100644
--- a/include/http_server.h
+++ b/include/http_server.h
@@ -99,7 +99,13 @@ namespace crow
};
timer.async_wait(handler);
- io_service_pool_[i]->run();
+ try
+ {
+ io_service_pool_[i]->run();
+ } catch(std::exception& e)
+ {
+ CROW_LOG_ERROR << "Worker Crash: An uncaught exception occurred: " << e.what();
+ }
}));
CROW_LOG_INFO << server_name_ << " server is running, local port " << port_;
diff --git a/include/mustache.h b/include/mustache.h
index b596b45..279f356 100644
--- a/include/mustache.h
+++ b/include/mustache.h
@@ -520,7 +520,11 @@ namespace crow
inline std::string default_loader(const std::string& filename)
{
- std::ifstream inf(detail::get_template_base_directory_ref() + filename);
+ std::string path = detail::get_template_base_directory_ref();
+ if (!(path.back() == '/' || path.back() == '\\'))
+ path += '/';
+ path += filename;
+ std::ifstream inf(path);
if (!inf)
return {};
return {std::istreambuf_iterator<char>(inf), std::istreambuf_iterator<char>()};
diff --git a/include/parser.h b/include/parser.h
index f6b748b..b621850 100644
--- a/include/parser.h
+++ b/include/parser.h
@@ -143,6 +143,11 @@ namespace crow
return request{(HTTPMethod)method, std::move(raw_url), std::move(url), std::move(url_params), std::move(headers), std::move(body)};
}
+ bool is_upgrade() const
+ {
+ return upgrade;
+ }
+
bool check_version(int major, int minor) const
{
return http_major == major && http_minor == minor;
diff --git a/include/routing.h b/include/routing.h
index 418209c..4fc2de8 100644
--- a/include/routing.h
+++ b/include/routing.h
@@ -13,6 +13,7 @@
#include "http_request.h"
#include "utility.h"
#include "logging.h"
+#include "websocket.h"
namespace crow
{
@@ -29,8 +30,26 @@ namespace crow
}
virtual void validate() = 0;
+ std::unique_ptr<BaseRule> upgrade()
+ {
+ if (rule_to_upgrade_)
+ return std::move(rule_to_upgrade_);
+ return {};
+ }
virtual void handle(const request&, response&, const routing_params&) = 0;
+ virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&)
+ {
+ res = response(404);
+ res.end();
+ }
+#ifdef CROW_ENABLE_SSL
+ virtual void handle_upgrade(const request&, response& res, SSLAdaptor&&)
+ {
+ res = response(404);
+ res.end();
+ }
+#endif
uint32_t get_methods()
{
@@ -42,6 +61,9 @@ namespace crow
std::string rule_;
std::string name_;
+
+ std::unique_ptr<BaseRule> rule_to_upgrade_;
+
friend class Router;
template <typename T>
friend struct RuleParameterTraits;
@@ -233,10 +255,82 @@ namespace crow
}
}
+ class WebSocketRule : public BaseRule
+ {
+ using self_t = WebSocketRule;
+ public:
+ WebSocketRule(std::string rule)
+ : BaseRule(std::move(rule))
+ {
+ }
+
+ void validate() override
+ {
+ }
+
+ void handle(const request&, response& res, const routing_params&) override
+ {
+ res = response(404);
+ res.end();
+ }
+
+ void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
+ {
+ new crow::websocket::Connection<SocketAdaptor>(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_);
+ }
+#ifdef CROW_ENABLE_SSL
+ void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
+ {
+ new crow::websocket::Connection<SSLAdaptor>(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_);
+ }
+#endif
+
+ template <typename Func>
+ self_t& onopen(Func f)
+ {
+ open_handler_ = f;
+ return *this;
+ }
+
+ template <typename Func>
+ self_t& onmessage(Func f)
+ {
+ message_handler_ = f;
+ return *this;
+ }
+
+ template <typename Func>
+ self_t& onclose(Func f)
+ {
+ close_handler_ = f;
+ return *this;
+ }
+
+ template <typename Func>
+ self_t& onerror(Func f)
+ {
+ error_handler_ = f;
+ return *this;
+ }
+
+ protected:
+ std::function<void(crow::websocket::connection&)> open_handler_;
+ std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
+ std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
+ std::function<void(crow::websocket::connection&)> error_handler_;
+ };
+
template <typename T>
struct RuleParameterTraits
{
using self_t = T;
+ WebSocketRule& websocket()
+ {
+ auto p =new WebSocketRule(((self_t*)this)->rule_);
+ ((self_t*)this)->rule_to_upgrade_.reset(p);
+ return *p;
+ }
+
self_t& name(std::string name) noexcept
{
((self_t*)this)->name_ = std::move(name);
@@ -256,6 +350,7 @@ namespace crow
((self_t*)this)->methods_ |= 1 << (int)method;
return (self_t&)*this;
}
+
};
class DynamicRule : public BaseRule, public RuleParameterTraits<DynamicRule>
@@ -343,7 +438,7 @@ namespace crow
{
}
- void validate()
+ void validate() override
{
if (!handler_)
{
@@ -809,10 +904,80 @@ public:
for(auto& rule:rules_)
{
if (rule)
+ {
+ auto upgraded = rule->upgrade();
+ if (upgraded)
+ rule = std::move(upgraded);
rule->validate();
+ }
}
}
+ template <typename Adaptor>
+ void handle_upgrade(const request& req, response& res, Adaptor&& adaptor)
+ {
+ auto found = trie_.find(req.url);
+ unsigned rule_index = found.first;
+ if (!rule_index)
+ {
+ CROW_LOG_DEBUG << "Cannot match rules " << req.url;
+ res = response(404);
+ res.end();
+ return;
+ }
+
+ if (rule_index >= rules_.size())
+ throw std::runtime_error("Trie internal structure corrupted!");
+
+ if (rule_index == RULE_SPECIAL_REDIRECT_SLASH)
+ {
+ CROW_LOG_INFO << "Redirecting to a url with trailing slash: " << req.url;
+ res = response(301);
+
+ // TODO absolute url building
+ if (req.get_header_value("Host").empty())
+ {
+ res.add_header("Location", req.url + "/");
+ }
+ else
+ {
+ res.add_header("Location", "http://" + req.get_header_value("Host") + req.url + "/");
+ }
+ res.end();
+ return;
+ }
+
+ if ((rules_[rule_index]->get_methods() & (1<<(uint32_t)req.method)) == 0)
+ {
+ CROW_LOG_DEBUG << "Rule found but method mismatch: " << req.url << " with " << method_name(req.method) << "(" << (uint32_t)req.method << ") / " << rules_[rule_index]->get_methods();
+ res = response(404);
+ res.end();
+ return;
+ }
+
+ CROW_LOG_DEBUG << "Matched rule (upgrade) '" << rules_[rule_index]->rule_ << "' " << (uint32_t)req.method << " / " << rules_[rule_index]->get_methods();
+
+ // any uncaught exceptions become 500s
+ try
+ {
+ rules_[rule_index]->handle_upgrade(req, res, std::move(adaptor));
+ }
+ catch(std::exception& e)
+ {
+ CROW_LOG_ERROR << "An uncaught exception occurred: " << e.what();
+ res = response(500);
+ res.end();
+ return;
+ }
+ catch(...)
+ {
+ CROW_LOG_ERROR << "An uncaught exception occurred. The type was unknown so no information was available.";
+ res = response(500);
+ res.end();
+ return;
+ }
+ }
+
void handle(const request& req, response& res)
{
auto found = trie_.find(req.url);
diff --git a/include/settings.h b/include/settings.h
index d8dfc9c..5c67f3b 100644
--- a/include/settings.h
+++ b/include/settings.h
@@ -8,7 +8,7 @@
/* #ifdef - enables logging */
#define CROW_ENABLE_LOGGING
-/* #ifdef - enables SSL */
+/* #ifdef - enables ssl */
//#define CROW_ENABLE_SSL
/* #define - specifies log level */
diff --git a/include/socket_adaptors.h b/include/socket_adaptors.h
index 201360c..634bd4b 100644
--- a/include/socket_adaptors.h
+++ b/include/socket_adaptors.h
@@ -1,5 +1,8 @@
#pragma once
#include <boost/asio.hpp>
+#ifdef CROW_ENABLE_SSL
+#include <boost/asio/ssl.hpp>
+#endif
#include "settings.h"
namespace crow
{
@@ -14,6 +17,11 @@ namespace crow
{
}
+ boost::asio::io_service& get_io_service()
+ {
+ return socket_.get_io_service();
+ }
+
tcp::socket& raw_socket()
{
return socket_;
@@ -52,20 +60,21 @@ namespace crow
struct SSLAdaptor
{
using context = boost::asio::ssl::context;
+ using ssl_socket_t = boost::asio::ssl::stream<tcp::socket>;
SSLAdaptor(boost::asio::io_service& io_service, context* ctx)
- : ssl_socket_(io_service, *ctx)
+ : ssl_socket_(new ssl_socket_t(io_service, *ctx))
{
}
boost::asio::ssl::stream<tcp::socket>& socket()
{
- return ssl_socket_;
+ return *ssl_socket_;
}
tcp::socket::lowest_layer_type&
raw_socket()
{
- return ssl_socket_.lowest_layer();
+ return ssl_socket_->lowest_layer();
}
tcp::endpoint remote_endpoint()
@@ -83,16 +92,21 @@ namespace crow
raw_socket().close();
}
+ boost::asio::io_service& get_io_service()
+ {
+ return raw_socket().get_io_service();
+ }
+
template <typename F>
void start(F f)
{
- ssl_socket_.async_handshake(boost::asio::ssl::stream_base::server,
+ ssl_socket_->async_handshake(boost::asio::ssl::stream_base::server,
[f](const boost::system::error_code& ec) {
f(ec);
});
}
- boost::asio::ssl::stream<tcp::socket> ssl_socket_;
+ std::unique_ptr<boost::asio::ssl::stream<tcp::socket>> ssl_socket_;
};
#endif
}
diff --git a/include/utility.h b/include/utility.h
index 183d65b..fe9029e 100644
--- a/include/utility.h
+++ b/include/utility.h
@@ -499,5 +499,47 @@ template <typename F, typename Set>
using arg = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
+ std::string base64encode(const char* data, size_t size, const char* key = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
+ {
+ std::string ret;
+ ret.resize((size+2) / 3 * 4);
+ auto it = ret.begin();
+ while(size >= 3)
+ {
+ *it++ = key[(((unsigned char)*data)&0xFC)>>2];
+ unsigned char h = (((unsigned char)*data++) & 0x03) << 4;
+ *it++ = key[h|((((unsigned char)*data)&0xF0)>>4)];
+ h = (((unsigned char)*data++) & 0x0F) << 2;
+ *it++ = key[h|((((unsigned char)*data)&0xC0)>>6)];
+ *it++ = key[((unsigned char)*data++)&0x3F];
+
+ size -= 3;
+ }
+ if (size == 1)
+ {
+ *it++ = key[(((unsigned char)*data)&0xFC)>>2];
+ unsigned char h = (((unsigned char)*data++) & 0x03) << 4;
+ *it++ = key[h];
+ *it++ = '=';
+ *it++ = '=';
+ }
+ else if (size == 2)
+ {
+ *it++ = key[(((unsigned char)*data)&0xFC)>>2];
+ unsigned char h = (((unsigned char)*data++) & 0x03) << 4;
+ *it++ = key[h|((((unsigned char)*data)&0xF0)>>4)];
+ h = (((unsigned char)*data++) & 0x0F) << 2;
+ *it++ = key[h];
+ *it++ = '=';
+ }
+ return ret;
+ }
+
+ std::string base64encode_urlsafe(const char* data, size_t size)
+ {
+ return base64encode(data, size, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
+ }
+
+
} // namespace utility
}
diff --git a/include/websocket.h b/include/websocket.h
new file mode 100644
index 0000000..5299c1a
--- /dev/null
+++ b/include/websocket.h
@@ -0,0 +1,482 @@
+#pragma once
+#include "socket_adaptors.h"
+#include "http_request.h"
+#include "TinySHA1.hpp"
+
+namespace crow
+{
+ namespace websocket
+ {
+ enum class WebSocketReadState
+ {
+ MiniHeader,
+ Len16,
+ Len64,
+ Mask,
+ Payload,
+ };
+
+ struct connection
+ {
+ virtual void send_binary(const std::string& msg) = 0;
+ virtual void send_text(const std::string& msg) = 0;
+ virtual void close(const std::string& msg = "quit") = 0;
+ virtual ~connection(){}
+ };
+
+ template <typename Adaptor>
+ class Connection : public connection
+ {
+ public:
+ Connection(const crow::request& req, Adaptor&& adaptor,
+ std::function<void(crow::websocket::connection&)> open_handler,
+ std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
+ std::function<void(crow::websocket::connection&, const std::string&)> close_handler,
+ std::function<void(crow::websocket::connection&)> error_handler)
+ : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler))
+ {
+ if (req.get_header_value("upgrade") != "websocket")
+ {
+ adaptor.close();
+ delete this;
+ return;
+ }
+ // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+ // Sec-WebSocket-Version: 13
+ std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+ sha1::SHA1 s;
+ s.processBytes(magic.data(), magic.size());
+ uint8_t digest[20];
+ s.getDigestBytes(digest);
+ start(crow::utility::base64encode((char*)digest, 20));
+ }
+
+ template<typename CompletionHandler>
+ void dispatch(CompletionHandler handler)
+ {
+ adaptor_.get_io_service().dispatch(handler);
+ }
+
+ template<typename CompletionHandler>
+ void post(CompletionHandler handler)
+ {
+ adaptor_.get_io_service().post(handler);
+ }
+
+ void send_pong(const std::string& msg)
+ {
+ dispatch([this, msg]{
+ char buf[3] = "\x8A\x00";
+ buf[1] += msg.size();
+ write_buffers_.emplace_back(buf, buf+2);
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ void send_binary(const std::string& msg) override
+ {
+ dispatch([this, msg]{
+ auto header = build_header(2, msg.size());
+ write_buffers_.emplace_back(std::move(header));
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ void send_text(const std::string& msg) override
+ {
+ dispatch([this, msg]{
+ auto header = build_header(1, msg.size());
+ write_buffers_.emplace_back(std::move(header));
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ void close(const std::string& msg) override
+ {
+ dispatch([this, msg]{
+ has_sent_close_ = true;
+ if (has_recv_close_ && !is_close_handler_called_)
+ {
+ is_close_handler_called_ = true;
+ if (close_handler_)
+ close_handler_(*this, msg);
+ }
+ auto header = build_header(0x8, msg.size());
+ write_buffers_.emplace_back(std::move(header));
+ write_buffers_.emplace_back(msg);
+ do_write();
+ });
+ }
+
+ protected:
+
+ std::string build_header(int opcode, size_t size)
+ {
+ char buf[2+8] = "\x80\x00";
+ buf[0] += opcode;
+ if (size < 126)
+ {
+ buf[1] += size;
+ return {buf, buf+2};
+ }
+ else if (size < 0x10000)
+ {
+ buf[1] += 126;
+ *(uint16_t*)(buf+2) = (uint16_t)size;
+ return {buf, buf+4};
+ }
+ else
+ {
+ buf[1] += 127;
+ *(uint64_t*)(buf+2) = (uint64_t)size;
+ return {buf, buf+10};
+ }
+ }
+
+ void start(std::string&& hello)
+ {
+ static std::string header = "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ";
+ static std::string crlf = "\r\n";
+ write_buffers_.emplace_back(header);
+ write_buffers_.emplace_back(std::move(hello));
+ write_buffers_.emplace_back(crlf);
+ write_buffers_.emplace_back(crlf);
+ do_write();
+ if (open_handler_)
+ open_handler_(*this);
+ do_read();
+ }
+
+ void do_read()
+ {
+ is_reading = true;
+ switch(state_)
+ {
+ case WebSocketReadState::MiniHeader:
+ {
+ //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1),
+ adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+ mini_header_ = htons(mini_header_);
+#ifdef CROW_ENABLE_DEBUG
+
+ if (!ec && bytes_transferred != 2)
+ {
+ throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec && ((mini_header_ & 0x80) == 0x80))
+ {
+ if ((mini_header_ & 0x7f) == 127)
+ {
+ state_ = WebSocketReadState::Len64;
+ }
+ else if ((mini_header_ & 0x7f) == 126)
+ {
+ state_ = WebSocketReadState::Len16;
+ }
+ else
+ {
+ remaining_length_ = mini_header_ & 0x7f;
+ state_ = WebSocketReadState::Mask;
+ }
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ adaptor_.close();
+ if (error_handler_)
+ error_handler_(*this);
+ check_destroy();
+ }
+ });
+ }
+ break;
+ case WebSocketReadState::Len16:
+ {
+ remaining_length_ = 0;
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 2),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+ remaining_length_ = ntohs(*(uint16_t*)&remaining_length_);
+#ifdef CROW_ENABLE_DEBUG
+ if (!ec && bytes_transferred != 2)
+ {
+ throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec)
+ {
+ state_ = WebSocketReadState::Mask;
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ adaptor_.close();
+ if (error_handler_)
+ error_handler_(*this);
+ check_destroy();
+ }
+ });
+ }
+ break;
+ case WebSocketReadState::Len64:
+ {
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+ remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32));
+#ifdef CROW_ENABLE_DEBUG
+ if (!ec && bytes_transferred != 8)
+ {
+ throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec)
+ {
+ state_ = WebSocketReadState::Mask;
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ adaptor_.close();
+ if (error_handler_)
+ error_handler_(*this);
+ check_destroy();
+ }
+ });
+ }
+ break;
+ case WebSocketReadState::Mask:
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+#ifdef CROW_ENABLE_DEBUG
+ if (!ec && bytes_transferred != 4)
+ {
+ throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?");
+ }
+#endif
+
+ if (!ec)
+ {
+ state_ = WebSocketReadState::Payload;
+ do_read();
+ }
+ else
+ {
+ close_connection_ = true;
+ if (error_handler_)
+ error_handler_(*this);
+ adaptor_.close();
+ }
+ });
+ break;
+ case WebSocketReadState::Payload:
+ {
+ size_t to_read = buffer_.size();
+ if (remaining_length_ < to_read)
+ to_read = remaining_length_;
+ adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read),
+ [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
+ {
+ is_reading = false;
+
+ if (!ec)
+ {
+ fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred);
+ remaining_length_ -= bytes_transferred;
+ if (remaining_length_ == 0)
+ {
+ handle_fragment();
+ state_ = WebSocketReadState::MiniHeader;
+ do_read();
+ }
+ }
+ else
+ {
+ close_connection_ = true;
+ if (error_handler_)
+ error_handler_(*this);
+ adaptor_.close();
+ }
+ });
+ }
+ break;
+ }
+ }
+
+ bool is_FIN()
+ {
+ return mini_header_ & 0x8000;
+ }
+
+ int opcode()
+ {
+ return (mini_header_ & 0x0f00) >> 8;
+ }
+
+ void handle_fragment()
+ {
+ for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++)
+ {
+ fragment_[i] ^= ((char*)&mask_)[i%4];
+ }
+ switch(opcode())
+ {
+ case 0: // Continuation
+ {
+ message_ += fragment_;
+ if (is_FIN())
+ {
+ if (message_handler_)
+ message_handler_(*this, message_, is_binary_);
+ message_.clear();
+ }
+ }
+ case 1: // Text
+ {
+ is_binary_ = false;
+ message_ += fragment_;
+ if (is_FIN())
+ {
+ if (message_handler_)
+ message_handler_(*this, message_, is_binary_);
+ message_.clear();
+ }
+ }
+ break;
+ case 2: // Binary
+ {
+ is_binary_ = true;
+ message_ += fragment_;
+ if (is_FIN())
+ {
+ if (message_handler_)
+ message_handler_(*this, message_, is_binary_);
+ message_.clear();
+ }
+ }
+ break;
+ case 0x8: // Close
+ {
+ has_recv_close_ = true;
+ if (!has_sent_close_)
+ {
+ close(fragment_);
+ }
+ else
+ {
+ adaptor_.close();
+ close_connection_ = true;
+ if (!is_close_handler_called_)
+ {
+ if (close_handler_)
+ close_handler_(*this, fragment_);
+ is_close_handler_called_ = true;
+ }
+ check_destroy();
+ }
+ }
+ break;
+ case 0x9: // Ping
+ {
+ send_pong(fragment_);
+ }
+ break;
+ case 0xA: // Pong
+ {
+ pong_received_ = true;
+ }
+ break;
+ }
+
+ fragment_.clear();
+ }
+
+ void do_write()
+ {
+ if (sending_buffers_.empty())
+ {
+ sending_buffers_.swap(write_buffers_);
+ std::vector<boost::asio::const_buffer> buffers;
+ buffers.reserve(sending_buffers_.size());
+ for(auto& s:sending_buffers_)
+ {
+ buffers.emplace_back(boost::asio::buffer(s));
+ }
+ boost::asio::async_write(adaptor_.socket(), buffers,
+ [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/)
+ {
+ sending_buffers_.clear();
+ if (!ec && !close_connection_)
+ {
+ if (!write_buffers_.empty())
+ do_write();
+ if (has_sent_close_)
+ close_connection_ = true;
+ }
+ else
+ {
+ close_connection_ = true;
+ check_destroy();
+ }
+ });
+ }
+ }
+
+ void check_destroy()
+ {
+ //if (has_sent_close_ && has_recv_close_)
+ if (!is_close_handler_called_)
+ if (close_handler_)
+ close_handler_(*this, "uncleanly");
+ if (sending_buffers_.empty() && !is_reading)
+ delete this;
+ }
+ private:
+ Adaptor adaptor_;
+
+ std::vector<std::string> sending_buffers_;
+ std::vector<std::string> write_buffers_;
+
+ boost::array<char, 4096> buffer_;
+ bool is_binary_;
+ std::string message_;
+ std::string fragment_;
+ WebSocketReadState state_{WebSocketReadState::MiniHeader};
+ uint64_t remaining_length_{0};
+ bool close_connection_{false};
+ bool is_reading{false};
+ uint32_t mask_;
+ uint16_t mini_header_;
+ bool has_sent_close_{false};
+ bool has_recv_close_{false};
+ bool error_occured_{false};
+ bool pong_received_{false};
+ bool is_close_handler_called_{false};
+
+ std::function<void(crow::websocket::connection&)> open_handler_;
+ std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
+ std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
+ std::function<void(crow::websocket::connection&)> error_handler_;
+ };
+ }
+}