diff options
author | ipknHama <ipknhama@gmail.com> | 2014-09-11 06:32:41 +0900 |
---|---|---|
committer | ipknHama <ipknhama@gmail.com> | 2014-09-11 06:32:41 +0900 |
commit | 9eb96b7f4c2e134e768555de2650823a863843ce (patch) | |
tree | 87797876b1a8bdf12f11d31e2fa28fc462075610 | |
parent | ab1063c046b363a37ccaf91c7dfb1fecd279be36 (diff) | |
download | crow-9eb96b7f4c2e134e768555de2650823a863843ce.tar.gz crow-9eb96b7f4c2e134e768555de2650823a863843ce.zip |
Implement example CookieParser middleware and test
-rw-r--r-- | amalgamate/crow_all.h | 1481 | ||||
-rw-r--r-- | include/ci_map.h | 32 | ||||
-rw-r--r-- | include/http_connection.h | 6 | ||||
-rw-r--r-- | include/http_request.h | 24 | ||||
-rw-r--r-- | include/http_response.h | 22 | ||||
-rw-r--r-- | include/middleware.h | 133 | ||||
-rw-r--r-- | include/parser.h | 4 | ||||
-rw-r--r-- | tests/unittest.cpp | 45 |
8 files changed, 1169 insertions, 578 deletions
diff --git a/amalgamate/crow_all.h b/amalgamate/crow_all.h index edcfde8..de3b1c9 100644 --- a/amalgamate/crow_all.h +++ b/amalgamate/crow_all.h @@ -1942,96 +1942,6 @@ namespace crow #pragma once -#include <string> -#include <unordered_map> - - - -namespace crow -{ - template <typename T> - class Connection; - struct response - { - template <typename T> - friend class crow::Connection; - - std::string body; - json::wvalue json_value; - int code{200}; - std::unordered_map<std::string, std::string> headers; - - response() {} - explicit response(int code) : code(code) {} - response(std::string body) : body(std::move(body)) {} - response(json::wvalue&& json_value) : json_value(std::move(json_value)) {} - response(const json::wvalue& json_value) : body(json::dump(json_value)) {} - response(int code, std::string body) : body(std::move(body)), code(code) {} - - response(response&& r) - { - *this = std::move(r); - } - - response& operator = (const response& r) = delete; - - response& operator = (response&& r) - { - body = std::move(r.body); - json_value = std::move(r.json_value); - code = r.code; - headers = std::move(r.headers); - completed_ = r.completed_; - return *this; - } - - void clear() - { - body.clear(); - json_value.clear(); - code = 200; - headers.clear(); - completed_ = false; - } - - void write(const std::string& body_part) - { - body += body_part; - } - - void end() - { - if (!completed_) - { - completed_ = true; - if (complete_request_handler_) - { - complete_request_handler_(); - } - } - } - - void end(const std::string& body_part) - { - body += body_part; - end(); - } - - bool is_alive() - { - return is_alive_helper_ && is_alive_helper_(); - } - - private: - bool completed_{}; - std::function<void()> complete_request_handler_; - std::function<bool()> is_alive_helper_; - }; -} - - - -#pragma once #include <boost/asio.hpp> #include <deque> @@ -4764,6 +4674,41 @@ http_parser_version(void) { #pragma once +#include <boost/functional/hash.hpp> + +namespace crow +{ + struct ci_hash + { + size_t operator()(const std::string& key) const + { + std::size_t seed = 0; + std::locale locale; + + for(auto c : key) + { + boost::hash_combine(seed, std::toupper(c, locale)); + } + + return seed; + } + }; + + struct ci_key_eq + { + bool operator()(const std::string& l, const std::string& r) const + { + return boost::iequals(l, r); + } + }; + + using ci_map = std::unordered_multimap<std::string, std::string, ci_hash, ci_key_eq>; +} + + + +#pragma once + #include <string> #include <boost/date_time/local_time/local_time.hpp> #include <boost/filesystem.hpp> @@ -5252,12 +5197,14 @@ template <typename F, typename Set> struct pop_back_helper<seq<N...>, Tuple> { template <template <typename ... Args> class U> - using rebind = U<std::tuple_element<N, Tuple>...>; + using rebind = U<typename std::tuple_element<N, Tuple>::type...>; }; template <typename ... T> - struct pop_back : public pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>> + struct pop_back //: public pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>> { + template <template <typename ... Args> class U> + using rebind = typename pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>>::template rebind<U>; }; template <> @@ -5280,6 +5227,11 @@ template <typename F, typename Set> template < typename Tp > struct contains<Tp> : std::false_type {}; + + template <typename T> + struct empty_context + { + }; } } @@ -5417,72 +5369,40 @@ constexpr crow::HTTPMethod operator "" _method(const char* str, size_t len) + + namespace crow { + template <typename T> + inline const std::string& get_header_value(const T& headers, const std::string& key) + { + if (headers.count(key)) + { + return headers.find(key)->second; + } + static std::string empty; + return empty; + } + struct request { HTTPMethod method; std::string url; - std::unordered_map<std::string, std::string> headers; + ci_map headers; std::string body; - void* middleware_context; - }; -} - - - -#pragma once - - - - - -namespace crow -{ - class CookieParser - { - struct context + void add_header(std::string key, std::string value) { - std::unordered_map<std::string, std::string> jar; - }; - - template <typename AllContext> - void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) - { - // ctx == all_ctx.bind<CookieParser>() - // ctx.jar[] = ; + headers.emplace(std::move(key), std::move(value)); } - template <typename AllContext> - void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + const std::string& get_header_value(const std::string& key) { + return crow::get_header_value(headers, key); } - } - /* - App<CookieParser, AnotherJarMW> app; - A B C - A::context - int aa; - - ctx1 : public A::context - ctx2 : public ctx1, public B::context - ctx3 : public ctx2, public C::context - - C depends on A - - C::handle - context.aaa - - App::context : private CookieParser::contetx, ... - { - jar - - } - - SimpleApp - */ + void* middleware_context{}; + }; } @@ -5521,7 +5441,6 @@ namespace crow case 0: if (!self->header_value.empty()) { - boost::algorithm::to_lower(self->header_field); self->headers.emplace(std::move(self->header_field), std::move(self->header_value)); } self->header_field.assign(at, at+length); @@ -5553,7 +5472,6 @@ namespace crow HTTPParser* self = static_cast<HTTPParser*>(self_); if (!self->header_field.empty()) { - boost::algorithm::to_lower(self->header_field); self->headers.emplace(std::move(self->header_field), std::move(self->header_value)); } self->process_header(); @@ -5634,7 +5552,7 @@ namespace crow int header_building_state = 0; std::string header_field; std::string header_value; - std::unordered_map<std::string, std::string> headers; + ci_map headers; std::string body; Handler* handler_; @@ -5644,15 +5562,8 @@ namespace crow #pragma once -#include <boost/asio.hpp> -#include <boost/algorithm/string/predicate.hpp> -#include <boost/lexical_cast.hpp> -#include <atomic> -#include <chrono> -#include <array> - - - +#include <string> +#include <unordered_map> @@ -5660,474 +5571,288 @@ namespace crow +namespace crow +{ + template <typename Handler, typename ... Middlewares> + class Connection; + struct response + { + template <typename Handler, typename ... Middlewares> + friend class crow::Connection; + std::string body; + json::wvalue json_value; + int code{200}; + // `headers' stores HTTP headers. + ci_map headers; + void set_header(std::string key, std::string value) + { + headers.erase(key); + headers.emplace(std::move(key), std::move(value)); + } + void add_header(std::string key, std::string value) + { + headers.emplace(std::move(key), std::move(value)); + } + const std::string& get_header_value(const std::string& key) + { + return crow::get_header_value(headers, key); + } + response() {} + explicit response(int code) : code(code) {} + response(std::string body) : body(std::move(body)) {} + response(json::wvalue&& json_value) : json_value(std::move(json_value)) {} + response(const json::wvalue& json_value) : body(json::dump(json_value)) {} + response(int code, std::string body) : body(std::move(body)), code(code) {} -namespace crow -{ - using namespace boost; - using tcp = asio::ip::tcp; -#ifdef CROW_ENABLE_DEBUG - static int connectionCount; -#endif - template <typename Handler> - class Connection - { - public: - Connection(boost::asio::io_service& io_service, Handler* handler, const std::string& server_name) - : socket_(io_service), - handler_(handler), - parser_(this), - server_name_(server_name) + response(response&& r) { -#ifdef CROW_ENABLE_DEBUG - connectionCount ++; - CROW_LOG_DEBUG << "Connection open, total " << connectionCount << ", " << this; -#endif + *this = std::move(r); } - - ~Connection() + + response& operator = (const response& r) = delete; + + response& operator = (response&& r) noexcept { - res.complete_request_handler_ = nullptr; - cancel_deadline_timer(); -#ifdef CROW_ENABLE_DEBUG - connectionCount --; - CROW_LOG_DEBUG << "Connection closed, total " << connectionCount << ", " << this; -#endif + body = std::move(r.body); + json_value = std::move(r.json_value); + code = r.code; + headers = std::move(r.headers); + completed_ = r.completed_; + return *this; } - tcp::socket& socket() + bool is_completed() const noexcept { - return socket_; + return completed_; } - void start() + void clear() { - //auto self = this->shared_from_this(); - start_deadline(); - - do_read(); + body.clear(); + json_value.clear(); + code = 200; + headers.clear(); + completed_ = false; } - void handle_header() + void write(const std::string& body_part) { - // HTTP 1.1 Expect: 100-continue - if (parser_.check_version(1, 1) && parser_.headers.count("expect") && parser_.headers["expect"] == "100-continue") - { - buffers_.clear(); - static std::string expect_100_continue = "HTTP/1.1 100 Continue\r\n\r\n"; - buffers_.emplace_back(expect_100_continue.data(), expect_100_continue.size()); - do_write(); - } + body += body_part; } - void handle() + void end() { - cancel_deadline_timer(); - bool is_invalid_request = false; - - request req = parser_.to_request(); - if (parser_.check_version(1, 0)) - { - // HTTP/1.0 - if (!(req.headers.count("connection") && boost::iequals(req.headers["connection"],"Keep-Alive"))) - close_connection_ = true; - } - else if (parser_.check_version(1, 1)) + if (!completed_) { - // HTTP/1.1 - if (req.headers.count("connection") && req.headers["connection"] == "close") - close_connection_ = true; - if (!req.headers.count("host")) + if (complete_request_handler_) { - is_invalid_request = true; - res = response(400); + complete_request_handler_(); } + completed_ = true; } - - CROW_LOG_INFO << "Request: " << boost::lexical_cast<std::string>(socket_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' ' - << method_name(req.method) << " " << req.url; - - - if (!is_invalid_request) - { - res.complete_request_handler_ = [this]{ this->complete_request(); }; - res.is_alive_helper_ = [this]()->bool{ return socket_.is_open(); }; - handler_->handle(req, res); - } - else - { - complete_request(); - } } - void complete_request() + void end(const std::string& body_part) { - CROW_LOG_INFO << "Response: " << this << ' ' << res.code << ' ' << close_connection_; - - //auto self = this->shared_from_this(); - res.complete_request_handler_ = nullptr; - - if (!socket_.is_open()) - { - //CROW_LOG_DEBUG << this << " delete (socket is closed) " << is_reading << ' ' << is_writing; - //delete this; - return; - } + body += body_part; + end(); + } - static std::unordered_map<int, std::string> statusCodes = { - {200, "HTTP/1.1 200 OK\r\n"}, - {201, "HTTP/1.1 201 Created\r\n"}, - {202, "HTTP/1.1 202 Accepted\r\n"}, - {204, "HTTP/1.1 204 No Content\r\n"}, + bool is_alive() + { + return is_alive_helper_ && is_alive_helper_(); + } - {300, "HTTP/1.1 300 Multiple Choices\r\n"}, - {301, "HTTP/1.1 301 Moved Permanently\r\n"}, - {302, "HTTP/1.1 302 Moved Temporarily\r\n"}, - {304, "HTTP/1.1 304 Not Modified\r\n"}, + private: + bool completed_{}; + std::function<void()> complete_request_handler_; + std::function<bool()> is_alive_helper_; + }; +} - {400, "HTTP/1.1 400 Bad Request\r\n"}, - {401, "HTTP/1.1 401 Unauthorized\r\n"}, - {403, "HTTP/1.1 403 Forbidden\r\n"}, - {404, "HTTP/1.1 404 Not Found\r\n"}, - {500, "HTTP/1.1 500 Internal Server Error\r\n"}, - {501, "HTTP/1.1 501 Not Implemented\r\n"}, - {502, "HTTP/1.1 502 Bad Gateway\r\n"}, - {503, "HTTP/1.1 503 Service Unavailable\r\n"}, - }; - static std::string seperator = ": "; - static std::string crlf = "\r\n"; +#pragma once +#include <boost/algorithm/string/trim.hpp> - buffers_.clear(); - buffers_.reserve(4*(res.headers.size()+4)+3); - if (res.body.empty() && res.json_value.t() == json::type::Object) - { - res.body = json::dump(res.json_value); - } - if (!statusCodes.count(res.code)) - res.code = 500; - { - auto& status = statusCodes.find(res.code)->second; - buffers_.emplace_back(status.data(), status.size()); - } - if (res.code >= 400 && res.body.empty()) - res.body = statusCodes[res.code].substr(9); - bool has_content_length = false; - bool has_date = false; - bool has_server = false; +namespace crow +{ + // Any middleware requires following 3 members: + + // struct context; + // storing data for the middleware; can be read from another middleware or handlers + + // before_handle + // called before handling the request. + // if res.end() is called, the operation is halted. + // (still call after_handle of this middleware) + // 2 signatures: + // void before_handle(request& req, response& res, context& ctx) + // if you only need to access this middlewares context. + // template <typename AllContext> + // void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + // you can access another middlewares' context by calling `all_ctx.template get<MW>()' + // ctx == all_ctx.template get<CurrentMiddleware>() + + // after_handle + // called after handling the request. + // void after_handle(request& req, response& res, context& ctx) + // template <typename AllContext> + // void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + + struct CookieParser + { + struct context + { + std::unordered_map<std::string, std::string> jar; + std::unordered_map<std::string, std::string> cookies_to_add; - for(auto& kv : res.headers) + std::string get_cookie(const std::string& key) { - buffers_.emplace_back(kv.first.data(), kv.first.size()); - buffers_.emplace_back(seperator.data(), seperator.size()); - buffers_.emplace_back(kv.second.data(), kv.second.size()); - buffers_.emplace_back(crlf.data(), crlf.size()); - - if (boost::iequals(kv.first, "content-length")) - has_content_length = true; - if (boost::iequals(kv.first, "date")) - has_date = true; - if (boost::iequals(kv.first, "server")) - has_server = true; + if (jar.count(key)) + return jar[key]; + return {}; } - if (!has_content_length) + void set_cookie(const std::string& key, const std::string& value) { - content_length_ = std::to_string(res.body.size()); - static std::string content_length_tag = "Content-Length: "; - buffers_.emplace_back(content_length_tag.data(), content_length_tag.size()); - buffers_.emplace_back(content_length_.data(), content_length_.size()); - buffers_.emplace_back(crlf.data(), crlf.size()); + cookies_to_add.emplace(key, value); } - if (!has_server) - { - static std::string server_tag = "Server: "; - buffers_.emplace_back(server_tag.data(), server_tag.size()); - buffers_.emplace_back(server_name_.data(), server_name_.size()); - buffers_.emplace_back(crlf.data(), crlf.size()); - } - if (!has_date) - { - static std::string date_tag = "Date: "; - date_str_ = get_cached_date_str(); - buffers_.emplace_back(date_tag.data(), date_tag.size()); - buffers_.emplace_back(date_str_.data(), date_str_.size()); - buffers_.emplace_back(crlf.data(), crlf.size()); - } - - buffers_.emplace_back(crlf.data(), crlf.size()); - buffers_.emplace_back(res.body.data(), res.body.size()); - - do_write(); - res.clear(); - } + }; - private: - static std::string get_cached_date_str() + void before_handle(request& req, response& res, context& ctx) { - using namespace std::chrono; - thread_local auto last = steady_clock::now(); - thread_local std::string date_str = DateTime().str(); - - if (steady_clock::now() - last >= seconds(1)) + int count = req.headers.count("Cookie"); + if (!count) + return; + if (count > 1) { - last = steady_clock::now(); - date_str = DateTime().str(); + res.code = 400; + res.end(); + return; } - return date_str; - } + std::string cookies = req.get_header_value("Cookie"); + size_t pos = 0; + while(pos < cookies.size()) + { + size_t pos_equal = cookies.find('=', pos); + if (pos_equal == cookies.npos) + break; + std::string name = cookies.substr(pos, pos_equal-pos); + boost::trim(name); + pos = pos_equal+1; + while(pos < cookies.size() && cookies[pos] == ' ') pos++; + if (pos == cookies.size()) + break; - void do_read() - { - //auto self = this->shared_from_this(); - is_reading = true; - socket_.async_read_some(boost::asio::buffer(buffer_), - [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + std::string value; + + if (cookies[pos] == '"') { - bool error_while_reading = true; - if (!ec) + int dquote_meet_count = 0; + pos ++; + size_t pos_dquote = pos-1; + do { - bool ret = parser_.feed(buffer_.data(), bytes_transferred); - if (ret && socket_.is_open() && !close_connection_) - { - error_while_reading = false; - } - } + pos_dquote = cookies.find('"', pos_dquote+1); + dquote_meet_count ++; + } while(pos_dquote < cookies.size() && cookies[pos_dquote-1] == '\\'); + if (pos_dquote == cookies.npos) + break; - if (error_while_reading) - { - cancel_deadline_timer(); - parser_.done(); - socket_.close(); - is_reading = false; - CROW_LOG_DEBUG << this << " from read(1)"; - check_destroy(); - } + if (dquote_meet_count == 1) + value = cookies.substr(pos, pos_dquote - pos); else { - start_deadline(); - do_read(); - } - }); - } - - void do_write() - { - //auto self = this->shared_from_this(); - is_writing = true; - boost::asio::async_write(socket_, buffers_, - [&](const boost::system::error_code& ec, std::size_t bytes_transferred) - { - is_writing = false; - if (!ec) - { - if (close_connection_) + value.clear(); + value.reserve(pos_dquote-pos); + for(size_t p = pos; p < pos_dquote; p++) { - socket_.close(); - CROW_LOG_DEBUG << this << " from write(1)"; - check_destroy(); + // FIXME minimal escaping + if (cookies[p] == '\\' && p + 1 < pos_dquote) + { + p++; + if (cookies[p] == '\\' || cookies[p] == '"') + value += cookies[p]; + else + { + value += '\\'; + value += cookies[p]; + } + } + else + value += cookies[p]; } } - else - { - CROW_LOG_DEBUG << this << " from write(2)"; - check_destroy(); - } - }); - } - void check_destroy() - { - CROW_LOG_DEBUG << this << " is_reading " << is_reading << " is_writing " << is_writing; - if (!is_reading && !is_writing) - { - CROW_LOG_DEBUG << this << " delete (idle) "; - delete this; + ctx.jar.emplace(std::move(name), std::move(value)); + pos = cookies.find(";", pos_dquote+1); + if (pos == cookies.npos) + break; + pos++; + while(pos < cookies.size() && cookies[pos] == ' ') pos++; + if (pos == cookies.size()) + break; + } + else + { + size_t pos_semicolon = cookies.find(';', pos); + value = cookies.substr(pos, pos_semicolon - pos); + boost::trim(value); + ctx.jar.emplace(std::move(name), std::move(value)); + pos = pos_semicolon; + if (pos == cookies.npos) + break; + pos ++; + while(pos < cookies.size() && cookies[pos] == ' ') pos++; + if (pos == cookies.size()) + break; + } } } - void cancel_deadline_timer() - { - CROW_LOG_DEBUG << this << " timer cancelled: " << timer_cancel_key_.first << ' ' << timer_cancel_key_.second; - detail::dumb_timer_queue::get_current_dumb_timer_queue().cancel(timer_cancel_key_); - } - - void start_deadline(int timeout = 5) + void after_handle(request& req, response& res, context& ctx) { - auto& timer_queue = detail::dumb_timer_queue::get_current_dumb_timer_queue(); - cancel_deadline_timer(); - - timer_cancel_key_ = timer_queue.add([this] + for(auto& cookie:ctx.cookies_to_add) { - if (!socket_.is_open()) - { - return; - } - socket_.close(); - }); - CROW_LOG_DEBUG << this << " timer added: " << timer_cancel_key_.first << ' ' << timer_cancel_key_.second; + res.add_header("Set-Cookie", cookie.first + "=" + cookie.second); + } } - - private: - tcp::socket socket_; - Handler* handler_; - - std::array<char, 4096> buffer_; - - HTTPParser<Connection> parser_; - response res; - - bool close_connection_ = false; - - const std::string& server_name_; - std::vector<boost::asio::const_buffer> buffers_; - - std::string content_length_; - std::string date_str_; - - //boost::asio::deadline_timer deadline_; - detail::dumb_timer_queue::key timer_cancel_key_; - - bool is_reading{}; - bool is_writing{}; }; -} - - - -#pragma once - -#include <boost/date_time/posix_time/posix_time.hpp> -#include <boost/asio.hpp> -#include <cstdint> -#include <atomic> -#include <future> - -#include <memory> - - - - - - + /* + App<CookieParser, AnotherJarMW> app; + A B C + A::context + int aa; + ctx1 : public A::context + ctx2 : public ctx1, public B::context + ctx3 : public ctx2, public C::context + C depends on A + C::handle + context.aaa -namespace crow -{ - using namespace boost; - using tcp = asio::ip::tcp; - - template <typename Handler> - class Server + App::context : private CookieParser::contetx, ... { - public: - Server(Handler* handler, uint16_t port, uint16_t concurrency = 1) - : acceptor_(io_service_, tcp::endpoint(asio::ip::address(), port)), - signals_(io_service_, SIGINT, SIGTERM), - handler_(handler), - concurrency_(concurrency), - port_(port) - { - } - - void run() - { - if (concurrency_ < 0) - concurrency_ = 1; - - for(int i = 0; i < concurrency_; i++) - io_service_pool_.emplace_back(new boost::asio::io_service()); - - std::vector<std::future<void>> v; - for(uint16_t i = 0; i < concurrency_; i ++) - v.push_back( - std::async(std::launch::async, [this, i]{ - auto& timer_queue = detail::dumb_timer_queue::get_current_dumb_timer_queue(); - timer_queue.set_io_service(*io_service_pool_[i]); - boost::asio::deadline_timer timer(*io_service_pool_[i]); - timer.expires_from_now(boost::posix_time::seconds(1)); - std::function<void(const boost::system::error_code& ec)> handler; - handler = [&](const boost::system::error_code& ec){ - if (ec) - return; - timer_queue.process(); - timer.expires_from_now(boost::posix_time::seconds(1)); - timer.async_wait(handler); - }; - timer.async_wait(handler); - io_service_pool_[i]->run(); - })); - CROW_LOG_INFO << server_name_ << " server is running, local port " << port_; - - signals_.async_wait( - [&](const boost::system::error_code& error, int signal_number){ - stop(); - }); - - do_accept(); - - v.push_back(std::async(std::launch::async, [this]{ - io_service_.run(); - CROW_LOG_INFO << "Exiting."; - })); - } - - void stop() - { - io_service_.stop(); - for(auto& io_service:io_service_pool_) - io_service->stop(); - } - - private: - asio::io_service& pick_io_service() - { - // TODO load balancing - roundrobin_index_++; - if (roundrobin_index_ >= io_service_pool_.size()) - roundrobin_index_ = 0; - return *io_service_pool_[roundrobin_index_]; - } - - void do_accept() - { - auto p = new Connection<Handler>(pick_io_service(), handler_, server_name_); - acceptor_.async_accept(p->socket(), - [this, p](boost::system::error_code ec) - { - if (!ec) - { - p->start(); - } - do_accept(); - }); - } + jar - private: - asio::io_service io_service_; - std::vector<std::unique_ptr<asio::io_service>> io_service_pool_; - tcp::acceptor acceptor_; - boost::asio::signal_set signals_; + } - Handler* handler_; - uint16_t concurrency_{1}; - std::string server_name_ = "Crow/0.1"; - uint16_t port_; - unsigned int roundrobin_index_{}; - }; + SimpleApp + */ } @@ -6809,22 +6534,47 @@ namespace crow : public black_magic::pop_back<Middlewares...>::template rebind<partial_context> , public black_magic::last_element_type<Middlewares...>::type::context { + using parent_context = typename black_magic::pop_back<Middlewares...>::template rebind<::crow::detail::partial_context>; + template <int N> + using partial = typename std::conditional<N == sizeof...(Middlewares)-1, partial_context, typename parent_context::template partial<N>>::type; + + template <typename T> + typename T::context& get() + { + return static_cast<typename T::context&>(*this); + } }; template <> struct partial_context<> { + template <int> + using partial = partial_context; }; + template <int N, typename Context, typename Container, typename CurrentMW, typename ... Middlewares> + bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx); + template <typename ... Middlewares> struct context : private partial_context<Middlewares...> //struct context : private Middlewares::context... // simple but less type-safe { + template <int N, typename Context, typename Container> + friend typename std::enable_if<(N==0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res); + template <int N, typename Context, typename Container> + friend typename std::enable_if<(N>0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res); + + template <int N, typename Context, typename Container, typename CurrentMW, typename ... Middlewares2> + friend bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx); + template <typename T> typename T::context& get() { return static_cast<typename T::context&>(*this); } + + template <int N> + using partial = typename partial_context<Middlewares...>::template partial<N>; }; } } @@ -6832,6 +6582,626 @@ namespace crow #pragma once +#include <boost/asio.hpp> +#include <boost/algorithm/string/predicate.hpp> +#include <boost/lexical_cast.hpp> +#include <boost/array.hpp> +#include <atomic> +#include <chrono> + + + + + + + + + + + + + + + + + + + +namespace crow +{ + namespace detail + { + template <typename MW, typename Context, typename ParentContext> + void before_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>()))* dummy = 0) + { + mw.before_handle(req, res, ctx.template get<MW>()); + } + + template <typename MW, typename Context, typename ParentContext> + void before_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>(), std::declval<Context&>))* dummy = 0) + { + mw.before_handle(req, res, ctx.template get<MW>(), parent_ctx); + } + + template <typename MW, typename Context, typename ParentContext> + void after_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>()))* dummy = 0) + { + mw.after_handle(req, res, ctx.template get<MW>()); + } + + template <typename MW, typename Context, typename ParentContext> + void after_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>(), std::declval<Context&>))* dummy = 0) + { + mw.after_handle(req, res, ctx.template get<MW>(), parent_ctx); + } + + template <int N, typename Context, typename Container, typename CurrentMW, typename ... Middlewares> + bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx) + { + using parent_context_t = typename Context::template partial<N-1>; + using current_context_t = typename Context::template partial<N>; + before_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + + if (res.is_completed()) + { + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + return true; + } + + if (middleware_call_helper<N+1, Context, Container, Middlewares...>(middlewares, req, res, ctx)) + { + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + return true; + } + + return false; + } + + template <int N, typename Context, typename Container> + bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx) + { + return false; + } + + template <int N, typename Context, typename Container> + typename std::enable_if<(N<0)>::type + after_handlers_call_helper(Container& middlewares, Context& context, request& req, response& res) + { + } + + template <int N, typename Context, typename Container> + typename std::enable_if<(N==0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res) + { + using parent_context_t = typename Context::template partial<N-1>; + using current_context_t = typename Context::template partial<N>; + using CurrentMW = typename std::tuple_element<N, typename std::remove_reference<Container>::type>::type; + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + } + + template <int N, typename Context, typename Container> + typename std::enable_if<(N>0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res) + { + using parent_context_t = typename Context::template partial<N-1>; + using current_context_t = typename Context::template partial<N>; + using CurrentMW = typename std::tuple_element<N, typename std::remove_reference<Container>::type>::type; + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + after_handlers_call_helper<N-1, Context, Container>(middlewares, ctx, req, res); + } + } + + using namespace boost; + using tcp = asio::ip::tcp; +#ifdef CROW_ENABLE_DEBUG + static int connectionCount; +#endif + template <typename Handler, typename ... Middlewares> + class Connection + { + public: + Connection( + boost::asio::io_service& io_service, + Handler* handler, + const std::string& server_name, + std::tuple<Middlewares...>& middlewares + ) + : socket_(io_service), + handler_(handler), + parser_(this), + server_name_(server_name), + middlewares_(middlewares) + { +#ifdef CROW_ENABLE_DEBUG + connectionCount ++; + CROW_LOG_DEBUG << "Connection open, total " << connectionCount << ", " << this; +#endif + } + + ~Connection() + { + res.complete_request_handler_ = nullptr; + cancel_deadline_timer(); +#ifdef CROW_ENABLE_DEBUG + connectionCount --; + CROW_LOG_DEBUG << "Connection closed, total " << connectionCount << ", " << this; +#endif + } + + tcp::socket& socket() + { + return socket_; + } + + void start() + { + //auto self = this->shared_from_this(); + start_deadline(); + + do_read(); + } + + void handle_header() + { + // HTTP 1.1 Expect: 100-continue + if (parser_.check_version(1, 1) && parser_.headers.count("expect") && get_header_value(parser_.headers, "expect") == "100-continue") + { + buffers_.clear(); + static std::string expect_100_continue = "HTTP/1.1 100 Continue\r\n\r\n"; + buffers_.emplace_back(expect_100_continue.data(), expect_100_continue.size()); + do_write(); + } + } + + void handle() + { + cancel_deadline_timer(); + bool is_invalid_request = false; + + req_ = std::move(parser_.to_request()); + request& req = req_; + if (parser_.check_version(1, 0)) + { + // HTTP/1.0 + if (!(req.headers.count("connection") && boost::iequals(req.get_header_value("connection"),"Keep-Alive"))) + close_connection_ = true; + } + else if (parser_.check_version(1, 1)) + { + // HTTP/1.1 + if (req.headers.count("connection") && req.get_header_value("connection") == "close") + close_connection_ = true; + if (!req.headers.count("host")) + { + is_invalid_request = true; + res = response(400); + } + } + + CROW_LOG_INFO << "Request: " << boost::lexical_cast<std::string>(socket_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' ' + << method_name(req.method) << " " << req.url; + + + need_to_call_after_handlers_ = false; + if (!is_invalid_request) + { + res.complete_request_handler_ = []{}; + res.is_alive_helper_ = [this]()->bool{ return socket_.is_open(); }; + + ctx_ = detail::context<Middlewares...>(); + req.middleware_context = (void*)&ctx_; + detail::middleware_call_helper<0, decltype(ctx_), decltype(middlewares_), Middlewares...>(middlewares_, req, res, ctx_); + CROW_LOG_DEBUG << "ALATDA " << req.url; + + if (!res.completed_) + { + res.complete_request_handler_ = [this]{ this->complete_request(); }; + need_to_call_after_handlers_ = true; + handler_->handle(req, res); + } + else + { + complete_request(); + } + } + else + { + complete_request(); + } + } + + void complete_request() + { + CROW_LOG_INFO << "Response: " << this << ' ' << req_.url << ' ' << res.code << ' ' << close_connection_; + + if (need_to_call_after_handlers_) + { + // call all after_handler of middlewares + detail::after_handlers_call_helper< + ((int)sizeof...(Middlewares)-1), + decltype(ctx_), + decltype(middlewares_)> + (middlewares_, ctx_, req_, res); + } + + //auto self = this->shared_from_this(); + res.complete_request_handler_ = nullptr; + + if (!socket_.is_open()) + { + //CROW_LOG_DEBUG << this << " delete (socket is closed) " << is_reading << ' ' << is_writing; + //delete this; + return; + } + + static std::unordered_map<int, std::string> statusCodes = { + {200, "HTTP/1.1 200 OK\r\n"}, + {201, "HTTP/1.1 201 Created\r\n"}, + {202, "HTTP/1.1 202 Accepted\r\n"}, + {204, "HTTP/1.1 204 No Content\r\n"}, + + {300, "HTTP/1.1 300 Multiple Choices\r\n"}, + {301, "HTTP/1.1 301 Moved Permanently\r\n"}, + {302, "HTTP/1.1 302 Moved Temporarily\r\n"}, + {304, "HTTP/1.1 304 Not Modified\r\n"}, + + {400, "HTTP/1.1 400 Bad Request\r\n"}, + {401, "HTTP/1.1 401 Unauthorized\r\n"}, + {403, "HTTP/1.1 403 Forbidden\r\n"}, + {404, "HTTP/1.1 404 Not Found\r\n"}, + + {500, "HTTP/1.1 500 Internal Server Error\r\n"}, + {501, "HTTP/1.1 501 Not Implemented\r\n"}, + {502, "HTTP/1.1 502 Bad Gateway\r\n"}, + {503, "HTTP/1.1 503 Service Unavailable\r\n"}, + }; + + static std::string seperator = ": "; + static std::string crlf = "\r\n"; + + buffers_.clear(); + buffers_.reserve(4*(res.headers.size()+4)+3); + + if (res.body.empty() && res.json_value.t() == json::type::Object) + { + res.body = json::dump(res.json_value); + } + + if (!statusCodes.count(res.code)) + res.code = 500; + { + auto& status = statusCodes.find(res.code)->second; + buffers_.emplace_back(status.data(), status.size()); + } + + if (res.code >= 400 && res.body.empty()) + res.body = statusCodes[res.code].substr(9); + + bool has_content_length = false; + bool has_date = false; + bool has_server = false; + + for(auto& kv : res.headers) + { + buffers_.emplace_back(kv.first.data(), kv.first.size()); + buffers_.emplace_back(seperator.data(), seperator.size()); + buffers_.emplace_back(kv.second.data(), kv.second.size()); + buffers_.emplace_back(crlf.data(), crlf.size()); + + if (boost::iequals(kv.first, "content-length")) + has_content_length = true; + if (boost::iequals(kv.first, "date")) + has_date = true; + if (boost::iequals(kv.first, "server")) + has_server = true; + } + + if (!has_content_length) + { + content_length_ = std::to_string(res.body.size()); + static std::string content_length_tag = "Content-Length: "; + buffers_.emplace_back(content_length_tag.data(), content_length_tag.size()); + buffers_.emplace_back(content_length_.data(), content_length_.size()); + buffers_.emplace_back(crlf.data(), crlf.size()); + } + if (!has_server) + { + static std::string server_tag = "Server: "; + buffers_.emplace_back(server_tag.data(), server_tag.size()); + buffers_.emplace_back(server_name_.data(), server_name_.size()); + buffers_.emplace_back(crlf.data(), crlf.size()); + } + if (!has_date) + { + static std::string date_tag = "Date: "; + date_str_ = get_cached_date_str(); + buffers_.emplace_back(date_tag.data(), date_tag.size()); + buffers_.emplace_back(date_str_.data(), date_str_.size()); + buffers_.emplace_back(crlf.data(), crlf.size()); + } + + buffers_.emplace_back(crlf.data(), crlf.size()); + buffers_.emplace_back(res.body.data(), res.body.size()); + + do_write(); + res.clear(); + } + + private: + static std::string get_cached_date_str() + { + using namespace std::chrono; + thread_local auto last = steady_clock::now(); + thread_local std::string date_str = DateTime().str(); + + if (steady_clock::now() - last >= seconds(1)) + { + last = steady_clock::now(); + date_str = DateTime().str(); + } + return date_str; + } + + void do_read() + { + //auto self = this->shared_from_this(); + is_reading = true; + socket_.async_read_some(boost::asio::buffer(buffer_), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + bool error_while_reading = true; + if (!ec) + { + bool ret = parser_.feed(buffer_.data(), bytes_transferred); + if (ret && socket_.is_open() && !close_connection_) + { + error_while_reading = false; + } + } + + if (error_while_reading) + { + cancel_deadline_timer(); + parser_.done(); + socket_.close(); + is_reading = false; + CROW_LOG_DEBUG << this << " from read(1)"; + check_destroy(); + } + else + { + start_deadline(); + do_read(); + } + }); + } + + void do_write() + { + //auto self = this->shared_from_this(); + is_writing = true; + boost::asio::async_write(socket_, buffers_, + [&](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_writing = false; + if (!ec) + { + if (close_connection_) + { + socket_.close(); + CROW_LOG_DEBUG << this << " from write(1)"; + check_destroy(); + } + } + else + { + CROW_LOG_DEBUG << this << " from write(2)"; + check_destroy(); + } + }); + } + + void check_destroy() + { + CROW_LOG_DEBUG << this << " is_reading " << is_reading << " is_writing " << is_writing; + if (!is_reading && !is_writing) + { + CROW_LOG_DEBUG << this << " delete (idle) "; + delete this; + } + } + + void cancel_deadline_timer() + { + CROW_LOG_DEBUG << this << " timer cancelled: " << timer_cancel_key_.first << ' ' << timer_cancel_key_.second; + detail::dumb_timer_queue::get_current_dumb_timer_queue().cancel(timer_cancel_key_); + } + + void start_deadline(int timeout = 5) + { + auto& timer_queue = detail::dumb_timer_queue::get_current_dumb_timer_queue(); + cancel_deadline_timer(); + + timer_cancel_key_ = timer_queue.add([this] + { + if (!socket_.is_open()) + { + return; + } + socket_.close(); + }); + CROW_LOG_DEBUG << this << " timer added: " << timer_cancel_key_.first << ' ' << timer_cancel_key_.second; + } + + private: + tcp::socket socket_; + Handler* handler_; + + boost::array<char, 4096> buffer_; + + HTTPParser<Connection> parser_; + request req_; + response res; + + bool close_connection_ = false; + + const std::string& server_name_; + std::vector<boost::asio::const_buffer> buffers_; + + std::string content_length_; + std::string date_str_; + + //boost::asio::deadline_timer deadline_; + detail::dumb_timer_queue::key timer_cancel_key_; + + bool is_reading{}; + bool is_writing{}; + bool need_to_call_after_handlers_; + + std::tuple<Middlewares...>& middlewares_; + detail::context<Middlewares...> ctx_; + }; + +} + + + +#pragma once + +#include <boost/date_time/posix_time/posix_time.hpp> +#include <boost/asio.hpp> +#include <cstdint> +#include <atomic> +#include <future> + +#include <memory> + + + + + + + + + + +namespace crow +{ + using namespace boost; + using tcp = asio::ip::tcp; + + template <typename Handler, typename ... Middlewares> + class Server + { + public: + Server(Handler* handler, uint16_t port, uint16_t concurrency = 1) + : acceptor_(io_service_, tcp::endpoint(asio::ip::address(), port)), + signals_(io_service_, SIGINT, SIGTERM), + handler_(handler), + concurrency_(concurrency), + port_(port) + { + } + + void run() + { + if (concurrency_ < 0) + concurrency_ = 1; + + for(int i = 0; i < concurrency_; i++) + io_service_pool_.emplace_back(new boost::asio::io_service()); + + std::vector<std::future<void>> v; + for(uint16_t i = 0; i < concurrency_; i ++) + v.push_back( + std::async(std::launch::async, [this, i]{ + // initializing timer queue + auto& timer_queue = detail::dumb_timer_queue::get_current_dumb_timer_queue(); + + timer_queue.set_io_service(*io_service_pool_[i]); + boost::asio::deadline_timer timer(*io_service_pool_[i]); + timer.expires_from_now(boost::posix_time::seconds(1)); + + std::function<void(const boost::system::error_code& ec)> handler; + handler = [&](const boost::system::error_code& ec){ + if (ec) + return; + timer_queue.process(); + timer.expires_from_now(boost::posix_time::seconds(1)); + timer.async_wait(handler); + }; + timer.async_wait(handler); + + io_service_pool_[i]->run(); + })); + CROW_LOG_INFO << server_name_ << " server is running, local port " << port_; + + signals_.async_wait( + [&](const boost::system::error_code& error, int signal_number){ + stop(); + }); + + do_accept(); + + v.push_back(std::async(std::launch::async, [this]{ + io_service_.run(); + CROW_LOG_INFO << "Exiting."; + })); + } + + void stop() + { + io_service_.stop(); + for(auto& io_service:io_service_pool_) + io_service->stop(); + } + + private: + asio::io_service& pick_io_service() + { + // TODO load balancing + roundrobin_index_++; + if (roundrobin_index_ >= io_service_pool_.size()) + roundrobin_index_ = 0; + return *io_service_pool_[roundrobin_index_]; + } + + void do_accept() + { + auto p = new Connection<Handler, Middlewares...>(pick_io_service(), handler_, server_name_, middlewares_); + acceptor_.async_accept(p->socket(), + [this, p](boost::system::error_code ec) + { + if (!ec) + { + p->start(); + } + do_accept(); + }); + } + + private: + asio::io_service io_service_; + std::vector<std::unique_ptr<asio::io_service>> io_service_pool_; + tcp::acceptor acceptor_; + boost::asio::signal_set signals_; + + Handler* handler_; + uint16_t concurrency_{1}; + std::string server_name_ = "Crow/0.1"; + uint16_t port_; + unsigned int roundrobin_index_{}; + + std::tuple<Middlewares...> middlewares_; + + }; +} + + + +#pragma once #include <string> #include <functional> @@ -6865,13 +7235,14 @@ namespace crow { public: using self_t = Crow; + using server_t = Server<Crow, Middlewares...>; Crow() { } void handle(const request& req, response& res) { - return router_.handle(req, res); + router_.handle(req, res); } template <uint64_t Tag> @@ -6908,7 +7279,7 @@ namespace crow void run() { validate(); - Server<self_t> server(this, port_, concurrency_); + server_t server(this, port_, concurrency_); server.run(); } @@ -6921,19 +7292,17 @@ namespace crow // middleware using context_t = detail::context<Middlewares...>; template <typename T> - T& get_middleware_context(request& req) + typename T::context& get_context(const request& req) { static_assert(black_magic::contains<T, Middlewares...>::value, "App doesn't have the specified middleware type."); auto& ctx = *reinterpret_cast<context_t*>(req.middleware_context); - return ctx.get<T>(); + return ctx.template get<T>(); } private: uint16_t port_ = 80; uint16_t concurrency_ = 1; - std::tuple<Middlewares...> middlewares_; - Router router_; }; template <typename ... Middlewares> diff --git a/include/ci_map.h b/include/ci_map.h new file mode 100644 index 0000000..9b48f0b --- /dev/null +++ b/include/ci_map.h @@ -0,0 +1,32 @@ +#pragma once + +#include <boost/functional/hash.hpp> + +namespace crow +{ + struct ci_hash + { + size_t operator()(const std::string& key) const + { + std::size_t seed = 0; + std::locale locale; + + for(auto c : key) + { + boost::hash_combine(seed, std::toupper(c, locale)); + } + + return seed; + } + }; + + struct ci_key_eq + { + bool operator()(const std::string& l, const std::string& r) const + { + return boost::iequals(l, r); + } + }; + + using ci_map = std::unordered_multimap<std::string, std::string, ci_hash, ci_key_eq>; +} diff --git a/include/http_connection.h b/include/http_connection.h index 3a9fe65..54d0860 100644 --- a/include/http_connection.h +++ b/include/http_connection.h @@ -155,7 +155,7 @@ namespace crow void handle_header() { // HTTP 1.1 Expect: 100-continue - if (parser_.check_version(1, 1) && parser_.headers.count("expect") && parser_.headers["expect"] == "100-continue") + if (parser_.check_version(1, 1) && parser_.headers.count("expect") && get_header_value(parser_.headers, "expect") == "100-continue") { buffers_.clear(); static std::string expect_100_continue = "HTTP/1.1 100 Continue\r\n\r\n"; @@ -174,13 +174,13 @@ namespace crow if (parser_.check_version(1, 0)) { // HTTP/1.0 - if (!(req.headers.count("connection") && boost::iequals(req.headers["connection"],"Keep-Alive"))) + if (!(req.headers.count("connection") && boost::iequals(req.get_header_value("connection"),"Keep-Alive"))) close_connection_ = true; } else if (parser_.check_version(1, 1)) { // HTTP/1.1 - if (req.headers.count("connection") && req.headers["connection"] == "close") + if (req.headers.count("connection") && req.get_header_value("connection") == "close") close_connection_ = true; if (!req.headers.count("host")) { diff --git a/include/http_request.h b/include/http_request.h index af623c6..7d2da67 100644 --- a/include/http_request.h +++ b/include/http_request.h @@ -1,16 +1,38 @@ #pragma once #include "common.h" +#include "ci_map.h" namespace crow { + template <typename T> + inline const std::string& get_header_value(const T& headers, const std::string& key) + { + if (headers.count(key)) + { + return headers.find(key)->second; + } + static std::string empty; + return empty; + } + struct request { HTTPMethod method; std::string url; - std::unordered_map<std::string, std::string> headers; + ci_map headers; std::string body; + void add_header(std::string key, std::string value) + { + headers.emplace(std::move(key), std::move(value)); + } + + const std::string& get_header_value(const std::string& key) + { + return crow::get_header_value(headers, key); + } + void* middleware_context{}; }; } diff --git a/include/http_response.h b/include/http_response.h index bc468b7..69c6d72 100644 --- a/include/http_response.h +++ b/include/http_response.h @@ -2,6 +2,8 @@ #include <string> #include <unordered_map> #include "json.h" +#include "http_request.h" +#include "ci_map.h" namespace crow { @@ -15,7 +17,25 @@ namespace crow std::string body; json::wvalue json_value; int code{200}; - std::unordered_map<std::string, std::string> headers; + + // `headers' stores HTTP headers. + ci_map headers; + + void set_header(std::string key, std::string value) + { + headers.erase(key); + headers.emplace(std::move(key), std::move(value)); + } + void add_header(std::string key, std::string value) + { + headers.emplace(std::move(key), std::move(value)); + } + + const std::string& get_header_value(const std::string& key) + { + return crow::get_header_value(headers, key); + } + response() {} explicit response(int code) : code(code) {} diff --git a/include/middleware.h b/include/middleware.h index 5e358af..ec54476 100644 --- a/include/middleware.h +++ b/include/middleware.h @@ -1,4 +1,5 @@ #pragma once +#include <boost/algorithm/string/trim.hpp> #include "http_request.h" #include "http_response.h" @@ -9,35 +10,143 @@ namespace crow // struct context; // storing data for the middleware; can be read from another middleware or handlers - // template <typename AllContext> - // void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + // before_handle // called before handling the request. // if res.end() is called, the operation is halted. // (still call after_handle of this middleware) + // 2 signatures: + // void before_handle(request& req, response& res, context& ctx) + // if you only need to access this middlewares context. + // template <typename AllContext> + // void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + // you can access another middlewares' context by calling `all_ctx.template get<MW>()' + // ctx == all_ctx.template get<CurrentMiddleware>() - // template <typename AllContext> - // void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + // after_handle // called after handling the request. + // void after_handle(request& req, response& res, context& ctx) + // template <typename AllContext> + // void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) - class CookieParser + struct CookieParser { struct context { std::unordered_map<std::string, std::string> jar; + std::unordered_map<std::string, std::string> cookies_to_add; + + std::string get_cookie(const std::string& key) + { + if (jar.count(key)) + return jar[key]; + return {}; + } + + void set_cookie(const std::string& key, const std::string& value) + { + cookies_to_add.emplace(key, value); + } }; - template <typename AllContext> - void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + void before_handle(request& req, response& res, context& ctx) { - // ctx == all_ctx.bind<CookieParser>() - // ctx.jar[] = ; + int count = req.headers.count("Cookie"); + if (!count) + return; + if (count > 1) + { + res.code = 400; + res.end(); + return; + } + std::string cookies = req.get_header_value("Cookie"); + size_t pos = 0; + while(pos < cookies.size()) + { + size_t pos_equal = cookies.find('=', pos); + if (pos_equal == cookies.npos) + break; + std::string name = cookies.substr(pos, pos_equal-pos); + boost::trim(name); + pos = pos_equal+1; + while(pos < cookies.size() && cookies[pos] == ' ') pos++; + if (pos == cookies.size()) + break; + + std::string value; + + if (cookies[pos] == '"') + { + int dquote_meet_count = 0; + pos ++; + size_t pos_dquote = pos-1; + do + { + pos_dquote = cookies.find('"', pos_dquote+1); + dquote_meet_count ++; + } while(pos_dquote < cookies.size() && cookies[pos_dquote-1] == '\\'); + if (pos_dquote == cookies.npos) + break; + + if (dquote_meet_count == 1) + value = cookies.substr(pos, pos_dquote - pos); + else + { + value.clear(); + value.reserve(pos_dquote-pos); + for(size_t p = pos; p < pos_dquote; p++) + { + // FIXME minimal escaping + if (cookies[p] == '\\' && p + 1 < pos_dquote) + { + p++; + if (cookies[p] == '\\' || cookies[p] == '"') + value += cookies[p]; + else + { + value += '\\'; + value += cookies[p]; + } + } + else + value += cookies[p]; + } + } + + ctx.jar.emplace(std::move(name), std::move(value)); + pos = cookies.find(";", pos_dquote+1); + if (pos == cookies.npos) + break; + pos++; + while(pos < cookies.size() && cookies[pos] == ' ') pos++; + if (pos == cookies.size()) + break; + } + else + { + size_t pos_semicolon = cookies.find(';', pos); + value = cookies.substr(pos, pos_semicolon - pos); + boost::trim(value); + ctx.jar.emplace(std::move(name), std::move(value)); + pos = pos_semicolon; + if (pos == cookies.npos) + break; + pos ++; + while(pos < cookies.size() && cookies[pos] == ' ') pos++; + if (pos == cookies.size()) + break; + } + } } - template <typename AllContext> - void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + void after_handle(request& req, response& res, context& ctx) { + for(auto& cookie:ctx.cookies_to_add) + { + res.add_header("Set-Cookie", cookie.first + "=" + cookie.second); + } } - } + }; /* App<CookieParser, AnotherJarMW> app; diff --git a/include/parser.h b/include/parser.h index 1b8240f..869061c 100644 --- a/include/parser.h +++ b/include/parser.h @@ -31,7 +31,6 @@ namespace crow case 0: if (!self->header_value.empty()) { - boost::algorithm::to_lower(self->header_field); self->headers.emplace(std::move(self->header_field), std::move(self->header_value)); } self->header_field.assign(at, at+length); @@ -63,7 +62,6 @@ namespace crow HTTPParser* self = static_cast<HTTPParser*>(self_); if (!self->header_field.empty()) { - boost::algorithm::to_lower(self->header_field); self->headers.emplace(std::move(self->header_field), std::move(self->header_value)); } self->process_header(); @@ -144,7 +142,7 @@ namespace crow int header_building_state = 0; std::string header_field; std::string header_value; - std::unordered_map<std::string, std::string> headers; + ci_map headers; std::string body; Handler* handler_; diff --git a/tests/unittest.cpp b/tests/unittest.cpp index a69e640..41387fa 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -7,6 +7,8 @@ #include "crow.h" #include "json.h" #include "mustache.h" +#include "middleware.h" + using namespace std; using namespace crow; @@ -187,7 +189,7 @@ TEST(RoutingTest) response res; req.url = "/4/5000/3/-2.71828/hellhere"; - req.headers["TestHeader"] = "Value"; + req.add_header("TestHeader", "Value"); app.handle(req, res); @@ -203,7 +205,7 @@ TEST(RoutingTest) response res; req.url = "/5/-5/999/3.141592/hello_there/a/b/c/d"; - req.headers["TestHeader"] = "Value"; + req.add_header("TestHeader", "Value"); app.handle(req, res); @@ -676,6 +678,45 @@ TEST(middleware_context) server.stop(); } +TEST(middleware_cookieparser) +{ + static char buf[2048]; + + App<CookieParser> app; + + std::string value1; + std::string value2; + + CROW_ROUTE(app, "/")([&](const request& req){ + { + auto& ctx = app.get_context<CookieParser>(req); + value1 = ctx.get_cookie("key1"); + value2 = ctx.get_cookie("key2"); + } + + return ""; + }); + + decltype(app)::server_t server(&app, 45451); + auto _ = async(launch::async, [&]{server.run();}); + std::string sendmsg = "GET /\r\nCookie: key1=value1; key2=\"val\\\"ue2\"\r\n\r\n"; + asio::io_service is; + { + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), 45451)); + + c.send(asio::buffer(sendmsg)); + + c.receive(asio::buffer(buf, 2048)); + c.close(); + } + { + ASSERT_EQUAL("value1", value1); + ASSERT_EQUAL("val\"ue2", value2); + } + server.stop(); +} + int main() { return testmain(); |