diff options
author | ipknHama <ipknhama@gmail.com> | 2014-09-08 07:07:53 +0900 |
---|---|---|
committer | ipknHama <ipknhama@gmail.com> | 2014-09-08 07:07:53 +0900 |
commit | ab1063c046b363a37ccaf91c7dfb1fecd279be36 (patch) | |
tree | 87b3a8cc331dc2a7130e0e6104fe1a4d1127d0ff | |
parent | 2748e35430b9a4aaf64dfbd626d819f0fc5eedd2 (diff) | |
download | crow-ab1063c046b363a37ccaf91c7dfb1fecd279be36.tar.gz crow-ab1063c046b363a37ccaf91c7dfb1fecd279be36.zip |
complete middleware implementation
-rw-r--r-- | amalgamate/crow_all.h | 492 | ||||
-rw-r--r-- | include/crow.h | 2 | ||||
-rw-r--r-- | include/http_connection.h | 89 | ||||
-rw-r--r-- | include/middleware_context.h | 25 | ||||
-rw-r--r-- | include/utility.h | 11 | ||||
-rw-r--r-- | tests/unittest.cpp | 146 |
6 files changed, 673 insertions, 92 deletions
diff --git a/amalgamate/crow_all.h b/amalgamate/crow_all.h index 5742e54..edcfde8 100644 --- a/amalgamate/crow_all.h +++ b/amalgamate/crow_all.h @@ -2031,6 +2031,94 @@ namespace crow +#pragma once + +#include <boost/asio.hpp> +#include <deque> +#include <functional> +#include <chrono> +#include <thread> + +namespace crow +{ + namespace detail + { + // fast timer queue for fixed tick value. + class dumb_timer_queue + { + public: + // tls based queue to avoid locking + static dumb_timer_queue& get_current_dumb_timer_queue() + { + thread_local dumb_timer_queue q; + return q; + } + + using key = std::pair<dumb_timer_queue*, int>; + + void cancel(key& k) + { + auto self = k.first; + k.first = nullptr; + if (!self) + return; + + unsigned int index = (unsigned int)(k.second - self->step_); + if (index < self->dq_.size()) + self->dq_[index].second = nullptr; + } + + key add(std::function<void()> f) + { + dq_.emplace_back(std::chrono::steady_clock::now(), std::move(f)); + int ret = step_+dq_.size()-1; + + CROW_LOG_DEBUG << "timer add inside: " << this << ' ' << ret ; + return {this, ret}; + } + + void process() + { + if (!io_service_) + return; + + auto now = std::chrono::steady_clock::now(); + while(!dq_.empty()) + { + auto& x = dq_.front(); + if (now - x.first < std::chrono::seconds(tick)) + break; + if (x.second) + { + CROW_LOG_DEBUG << "timer call: " << this << ' ' << step_; + // we know that timer handlers are very simple currenty; call here + x.second(); + } + dq_.pop_front(); + step_++; + } + } + + void set_io_service(boost::asio::io_service& io_service) + { + io_service_ = &io_service; + } + + private: + dumb_timer_queue() noexcept + { + } + + int tick{5}; + boost::asio::io_service* io_service_{}; + std::deque<std::pair<decltype(std::chrono::steady_clock::now()), std::function<void()>>> dq_; + int step_{}; + }; + } +} + + + /* merged revision: 5b951d74bd66ec9d38448e0a85b1cf8b85d97db3 */ /* Copyright Joyent, Inc. and other Node contributors. All rights reserved. * @@ -4887,13 +4975,21 @@ namespace crow }; } -#define CROW_LOG_CRITICAL crow::logger("CRITICAL", crow::LogLevel::CRITICAL) -#define CROW_LOG_ERROR crow::logger("ERROR ", crow::LogLevel::ERROR) -#define CROW_LOG_WARNING crow::logger("WARNING ", crow::LogLevel::WARNING) -#define CROW_LOG_INFO crow::logger("INFO ", crow::LogLevel::INFO) -#define CROW_LOG_DEBUG crow::logger("DEBUG ", crow::LogLevel::DEBUG) - - +#define CROW_LOG_CRITICAL \ + if (crow::logger::get_current_log_level() <= crow::LogLevel::CRITICAL) \ + crow::logger("CRITICAL", crow::LogLevel::CRITICAL) +#define CROW_LOG_ERROR \ + if (crow::logger::get_current_log_level() <= crow::LogLevel::ERROR) \ + crow::logger("ERROR ", crow::LogLevel::ERROR) +#define CROW_LOG_WARNING \ + if (crow::logger::get_current_log_level() <= crow::LogLevel::WARNING) \ + crow::logger("WARNING ", crow::LogLevel::WARNING) +#define CROW_LOG_INFO \ + if (crow::logger::get_current_log_level() <= crow::LogLevel::INFO) \ + crow::logger("INFO ", crow::LogLevel::INFO) +#define CROW_LOG_DEBUG \ + if (crow::logger::get_current_log_level() <= crow::LogLevel::DEBUG) \ + crow::logger("DEBUG ", crow::LogLevel::DEBUG) @@ -4902,6 +4998,8 @@ namespace crow #include <cstdint> #include <stdexcept> +#include <tuple> +#include <type_traits> namespace crow { @@ -5111,6 +5209,77 @@ template <typename F, typename Set> using type = S<>; }; + template <typename ... T> + struct last_element_type + { + using type = typename std::tuple_element<sizeof...(T)-1, std::tuple<T...>>::type; + }; + + + template <> + struct last_element_type<> + { + }; + + + // from http://stackoverflow.com/questions/13072359/c11-compile-time-array-with-logarithmic-evaluation-depth + template<class T> using Invoke = typename T::type; + + template<unsigned...> struct seq{ using type = seq; }; + + template<class S1, class S2> struct concat; + + template<unsigned... I1, unsigned... I2> + struct concat<seq<I1...>, seq<I2...>> + : seq<I1..., (sizeof...(I1)+I2)...>{}; + + template<class S1, class S2> + using Concat = Invoke<concat<S1, S2>>; + + template<unsigned N> struct gen_seq; + template<unsigned N> using GenSeq = Invoke<gen_seq<N>>; + + template<unsigned N> + struct gen_seq : Concat<GenSeq<N/2>, GenSeq<N - N/2>>{}; + + template<> struct gen_seq<0> : seq<>{}; + template<> struct gen_seq<1> : seq<0>{}; + + template <typename Seq, typename Tuple> + struct pop_back_helper; + + template <unsigned ... N, typename Tuple> + struct pop_back_helper<seq<N...>, Tuple> + { + template <template <typename ... Args> class U> + using rebind = U<std::tuple_element<N, Tuple>...>; + }; + + template <typename ... T> + struct pop_back : public pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>> + { + }; + + template <> + struct pop_back<> + { + template <template <typename ... Args> class U> + using rebind = U<>; + }; + + // from http://stackoverflow.com/questions/2118541/check-if-c0x-parameter-pack-contains-a-type + template < typename Tp, typename... List > + struct contains : std::true_type {}; + + template < typename Tp, typename Head, typename... Rest > + struct contains<Tp, Head, Rest...> + : std::conditional< std::is_same<Tp, Head>::value, + std::true_type, + contains<Tp, Rest...> + >::type {}; + + template < typename Tp > + struct contains<Tp> : std::false_type {}; } } @@ -5256,6 +5425,8 @@ namespace crow std::string url; std::unordered_map<std::string, std::string> headers; std::string body; + + void* middleware_context; }; } @@ -5263,6 +5434,61 @@ namespace crow #pragma once + + + + +namespace crow +{ + class CookieParser + { + struct context + { + std::unordered_map<std::string, std::string> jar; + }; + + template <typename AllContext> + void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + { + // ctx == all_ctx.bind<CookieParser>() + // ctx.jar[] = ; + } + + template <typename AllContext> + void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + { + } + } + + /* + App<CookieParser, AnotherJarMW> app; + A B C + A::context + int aa; + + ctx1 : public A::context + ctx2 : public ctx1, public B::context + ctx3 : public ctx2, public C::context + + C depends on A + + C::handle + context.aaa + + App::context : private CookieParser::contetx, ... + { + jar + + } + + SimpleApp + */ +} + + + +#pragma once + #include <string> #include <unordered_map> #include <boost/algorithm/string.hpp> @@ -5346,7 +5572,15 @@ namespace crow return 0; } HTTPParser(Handler* handler) : - settings_ { + handler_(handler) + { + http_parser_init(this, HTTP_REQUEST); + } + + // return false on error + bool feed(const char* buffer, int length) + { + const static http_parser_settings settings_{ on_message_begin, on_url, nullptr, @@ -5355,23 +5589,15 @@ namespace crow on_headers_complete, on_body, on_message_complete, - }, - handler_(handler) - { - http_parser_init(this, HTTP_REQUEST); - } + }; - // return false on error - bool feed(const char* buffer, int length) - { int nparsed = http_parser_execute(this, &settings_, buffer, length); return nparsed == length; } bool done() { - int nparsed = http_parser_execute(this, &settings_, nullptr, 0); - return nparsed == 0; + return feed(nullptr, 0); } void clear() @@ -5411,8 +5637,6 @@ namespace crow std::unordered_map<std::string, std::string> headers; std::string body; - http_parser_settings settings_; - Handler* handler_; }; } @@ -5441,6 +5665,8 @@ namespace crow + + namespace crow { using namespace boost; @@ -5449,16 +5675,14 @@ namespace crow static int connectionCount; #endif template <typename Handler> - class Connection : public std::enable_shared_from_this<Connection<Handler>> + class Connection { public: - Connection(tcp::socket&& socket, Handler* handler, const std::string& server_name) - : socket_(std::move(socket)), + Connection(boost::asio::io_service& io_service, Handler* handler, const std::string& server_name) + : socket_(io_service), handler_(handler), parser_(this), - server_name_(server_name), - deadline_(socket_.get_io_service()), - address_str_(boost::lexical_cast<std::string>(socket_.remote_endpoint())) + server_name_(server_name) { #ifdef CROW_ENABLE_DEBUG connectionCount ++; @@ -5469,15 +5693,21 @@ namespace crow ~Connection() { res.complete_request_handler_ = nullptr; + cancel_deadline_timer(); #ifdef CROW_ENABLE_DEBUG connectionCount --; CROW_LOG_DEBUG << "Connection closed, total " << connectionCount << ", " << this; #endif } + tcp::socket& socket() + { + return socket_; + } + void start() { - auto self = this->shared_from_this(); + //auto self = this->shared_from_this(); start_deadline(); do_read(); @@ -5497,6 +5727,7 @@ namespace crow void handle() { + cancel_deadline_timer(); bool is_invalid_request = false; request req = parser_.to_request(); @@ -5518,15 +5749,13 @@ namespace crow } } - CROW_LOG_INFO << "Request: " << address_str_ << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' ' + CROW_LOG_INFO << "Request: " << boost::lexical_cast<std::string>(socket_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' ' << method_name(req.method) << " " << req.url; if (!is_invalid_request) { - deadline_.cancel(); - auto self = this->shared_from_this(); - res.complete_request_handler_ = [self]{ self->complete_request(); }; + res.complete_request_handler_ = [this]{ this->complete_request(); }; res.is_alive_helper_ = [this]()->bool{ return socket_.is_open(); }; handler_->handle(req, res); } @@ -5540,12 +5769,15 @@ namespace crow { CROW_LOG_INFO << "Response: " << this << ' ' << res.code << ' ' << close_connection_; - auto self = this->shared_from_this(); + //auto self = this->shared_from_this(); res.complete_request_handler_ = nullptr; - + if (!socket_.is_open()) + { + //CROW_LOG_DEBUG << this << " delete (socket is closed) " << is_reading << ' ' << is_writing; + //delete this; return; - + } static std::unordered_map<int, std::string> statusCodes = { {200, "HTTP/1.1 200 OK\r\n"}, @@ -5657,74 +5889,100 @@ namespace crow void do_read() { - auto self = this->shared_from_this(); + //auto self = this->shared_from_this(); + is_reading = true; socket_.async_read_some(boost::asio::buffer(buffer_), - [self, this](const boost::system::error_code& ec, std::size_t bytes_transferred) + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) { bool error_while_reading = true; if (!ec) { bool ret = parser_.feed(buffer_.data(), bytes_transferred); - if (ret) + if (ret && socket_.is_open() && !close_connection_) { - do_read(); error_while_reading = false; } } if (error_while_reading) { - deadline_.cancel(); + cancel_deadline_timer(); parser_.done(); socket_.close(); + is_reading = false; + CROW_LOG_DEBUG << this << " from read(1)"; + check_destroy(); } else { start_deadline(); + do_read(); } }); } void do_write() { - auto self = this->shared_from_this(); + //auto self = this->shared_from_this(); + is_writing = true; boost::asio::async_write(socket_, buffers_, - [&, self](const boost::system::error_code& ec, std::size_t bytes_transferred) + [&](const boost::system::error_code& ec, std::size_t bytes_transferred) { + is_writing = false; if (!ec) { - start_deadline(); if (close_connection_) { socket_.close(); + CROW_LOG_DEBUG << this << " from write(1)"; + check_destroy(); } } + else + { + CROW_LOG_DEBUG << this << " from write(2)"; + check_destroy(); + } }); } + void check_destroy() + { + CROW_LOG_DEBUG << this << " is_reading " << is_reading << " is_writing " << is_writing; + if (!is_reading && !is_writing) + { + CROW_LOG_DEBUG << this << " delete (idle) "; + delete this; + } + } + + void cancel_deadline_timer() + { + CROW_LOG_DEBUG << this << " timer cancelled: " << timer_cancel_key_.first << ' ' << timer_cancel_key_.second; + detail::dumb_timer_queue::get_current_dumb_timer_queue().cancel(timer_cancel_key_); + } + void start_deadline(int timeout = 5) { - deadline_.expires_from_now(boost::posix_time::seconds(timeout)); - auto self = this->shared_from_this(); - deadline_.async_wait([self, this](const boost::system::error_code& ec) + auto& timer_queue = detail::dumb_timer_queue::get_current_dumb_timer_queue(); + cancel_deadline_timer(); + + timer_cancel_key_ = timer_queue.add([this] { - if (ec || !socket_.is_open()) + if (!socket_.is_open()) { return; } - bool is_deadline_passed = deadline_.expires_at() <= boost::asio::deadline_timer::traits_type::now(); - if (is_deadline_passed) - { - socket_.close(); - } + socket_.close(); }); + CROW_LOG_DEBUG << this << " timer added: " << timer_cancel_key_.first << ' ' << timer_cancel_key_.second; } private: tcp::socket socket_; Handler* handler_; - std::array<char, 8192> buffer_; + std::array<char, 4096> buffer_; HTTPParser<Connection> parser_; response res; @@ -5737,8 +5995,11 @@ namespace crow std::string content_length_; std::string date_str_; - boost::asio::deadline_timer deadline_; - std::string address_str_; + //boost::asio::deadline_timer deadline_; + detail::dumb_timer_queue::key timer_cancel_key_; + + bool is_reading{}; + bool is_writing{}; }; } @@ -5747,6 +6008,7 @@ namespace crow #pragma once +#include <boost/date_time/posix_time/posix_time.hpp> #include <boost/asio.hpp> #include <cstdint> #include <atomic> @@ -5761,6 +6023,8 @@ namespace crow + + namespace crow { using namespace boost; @@ -5772,46 +6036,80 @@ namespace crow public: Server(Handler* handler, uint16_t port, uint16_t concurrency = 1) : acceptor_(io_service_, tcp::endpoint(asio::ip::address(), port)), - socket_(io_service_), signals_(io_service_, SIGINT, SIGTERM), handler_(handler), concurrency_(concurrency), port_(port) { - do_accept(); } void run() { + if (concurrency_ < 0) + concurrency_ = 1; + + for(int i = 0; i < concurrency_; i++) + io_service_pool_.emplace_back(new boost::asio::io_service()); + std::vector<std::future<void>> v; for(uint16_t i = 0; i < concurrency_; i ++) v.push_back( - std::async(std::launch::async, [this]{io_service_.run();}) - ); - + std::async(std::launch::async, [this, i]{ + auto& timer_queue = detail::dumb_timer_queue::get_current_dumb_timer_queue(); + timer_queue.set_io_service(*io_service_pool_[i]); + boost::asio::deadline_timer timer(*io_service_pool_[i]); + timer.expires_from_now(boost::posix_time::seconds(1)); + std::function<void(const boost::system::error_code& ec)> handler; + handler = [&](const boost::system::error_code& ec){ + if (ec) + return; + timer_queue.process(); + timer.expires_from_now(boost::posix_time::seconds(1)); + timer.async_wait(handler); + }; + timer.async_wait(handler); + io_service_pool_[i]->run(); + })); CROW_LOG_INFO << server_name_ << " server is running, local port " << port_; signals_.async_wait( [&](const boost::system::error_code& error, int signal_number){ - io_service_.stop(); + stop(); }); - + + do_accept(); + + v.push_back(std::async(std::launch::async, [this]{ + io_service_.run(); + CROW_LOG_INFO << "Exiting."; + })); } void stop() { io_service_.stop(); + for(auto& io_service:io_service_pool_) + io_service->stop(); } private: + asio::io_service& pick_io_service() + { + // TODO load balancing + roundrobin_index_++; + if (roundrobin_index_ >= io_service_pool_.size()) + roundrobin_index_ = 0; + return *io_service_pool_[roundrobin_index_]; + } + void do_accept() { - acceptor_.async_accept(socket_, - [this](boost::system::error_code ec) + auto p = new Connection<Handler>(pick_io_service(), handler_, server_name_); + acceptor_.async_accept(p->socket(), + [this, p](boost::system::error_code ec) { if (!ec) { - auto p = std::make_shared<Connection<Handler>>(std::move(socket_), handler_, server_name_); p->start(); } do_accept(); @@ -5820,14 +6118,15 @@ namespace crow private: asio::io_service io_service_; + std::vector<std::unique_ptr<asio::io_service>> io_service_pool_; tcp::acceptor acceptor_; - tcp::socket socket_; boost::asio::signal_set signals_; Handler* handler_; uint16_t concurrency_{1}; std::string server_name_ = "Crow/0.1"; uint16_t port_; + unsigned int roundrobin_index_{}; }; } @@ -5965,9 +6264,8 @@ namespace crow ); return; } -#ifdef CROW_ENABLE_LOGGING - std::cerr << "ERROR cannot find handler" << std::endl; -#endif + CROW_LOG_DEBUG << "ERROR cannot find handler"; + // we already found matched url; this is server error cparams.res = response(500); } @@ -6499,6 +6797,42 @@ public: #pragma once + + + +namespace crow +{ + namespace detail + { + template <typename ... Middlewares> + struct partial_context + : public black_magic::pop_back<Middlewares...>::template rebind<partial_context> + , public black_magic::last_element_type<Middlewares...>::type::context + { + }; + + template <> + struct partial_context<> + { + }; + + template <typename ... Middlewares> + struct context : private partial_context<Middlewares...> + //struct context : private Middlewares::context... // simple but less type-safe + { + template <typename T> + typename T::context& get() + { + return static_cast<typename T::context&>(*this); + } + }; + } +} + + + +#pragma once + #include <string> #include <functional> #include <memory> @@ -6518,13 +6852,15 @@ public: -// TEST -#include <iostream> + + + #define CROW_ROUTE(app, url) app.route<crow::black_magic::get_parameter_tag(url)>(url) namespace crow { + template <typename ... Middlewares> class Crow { public: @@ -6582,13 +6918,27 @@ namespace crow router_.debug_print(); } + // middleware + using context_t = detail::context<Middlewares...>; + template <typename T> + T& get_middleware_context(request& req) + { + static_assert(black_magic::contains<T, Middlewares...>::value, "App doesn't have the specified middleware type."); + auto& ctx = *reinterpret_cast<context_t*>(req.middleware_context); + return ctx.get<T>(); + } + private: uint16_t port_ = 80; uint16_t concurrency_ = 1; + std::tuple<Middlewares...> middlewares_; + Router router_; }; - using App = Crow; + template <typename ... Middlewares> + using App = Crow<Middlewares...>; + using SimpleApp = Crow<>; }; diff --git a/include/crow.h b/include/crow.h index a4b82df..fdc5206 100644 --- a/include/crow.h +++ b/include/crow.h @@ -82,7 +82,7 @@ namespace crow // middleware using context_t = detail::context<Middlewares...>; template <typename T> - typename T::context& get_middleware_context(const request& req) + typename T::context& get_context(const request& req) { static_assert(black_magic::contains<T, Middlewares...>::value, "App doesn't have the specified middleware type."); auto& ctx = *reinterpret_cast<context_t*>(req.middleware_context); diff --git a/include/http_connection.h b/include/http_connection.h index a8631d9..3a9fe65 100644 --- a/include/http_connection.h +++ b/include/http_connection.h @@ -20,21 +20,53 @@ namespace crow { namespace detail { + template <typename MW, typename Context, typename ParentContext> + void before_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>()))* dummy = 0) + { + mw.before_handle(req, res, ctx.template get<MW>()); + } + + template <typename MW, typename Context, typename ParentContext> + void before_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>(), std::declval<Context&>))* dummy = 0) + { + mw.before_handle(req, res, ctx.template get<MW>(), parent_ctx); + } + + template <typename MW, typename Context, typename ParentContext> + void after_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>()))* dummy = 0) + { + mw.after_handle(req, res, ctx.template get<MW>()); + } + + template <typename MW, typename Context, typename ParentContext> + void after_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& parent_ctx, + decltype(std::declval<MW>().before_handle(std::declval<request&>(), std::declval<response&>(), std::declval<typename MW::context&>(), std::declval<Context&>))* dummy = 0) + { + mw.after_handle(req, res, ctx.template get<MW>(), parent_ctx); + } + template <int N, typename Context, typename Container, typename CurrentMW, typename ... Middlewares> bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx) { - // TODO cut ctx to partial_context<0..N-1> - std::get<N>(middlewares).before_handle(req, res, ctx.template get<CurrentMW>(), ctx); + using parent_context_t = typename Context::template partial<N-1>; + using current_context_t = typename Context::template partial<N>; + before_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + if (res.is_completed()) { - std::get<N>(middlewares).after_handle(req, res, ctx.template get<CurrentMW>(), ctx); + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); return true; } - if (middleware_call_helper<N+1, Context, Middlewares...>(middlewares, req, res, ctx)) + + if (middleware_call_helper<N+1, Context, Container, Middlewares...>(middlewares, req, res, ctx)) { - std::get<N>(middlewares).after_handle(req, res, ctx.template get<CurrentMW>(), ctx); + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); return true; } + return false; } @@ -43,6 +75,31 @@ namespace crow { return false; } + + template <int N, typename Context, typename Container> + typename std::enable_if<(N<0)>::type + after_handlers_call_helper(Container& middlewares, Context& context, request& req, response& res) + { + } + + template <int N, typename Context, typename Container> + typename std::enable_if<(N==0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res) + { + using parent_context_t = typename Context::template partial<N-1>; + using current_context_t = typename Context::template partial<N>; + using CurrentMW = typename std::tuple_element<N, typename std::remove_reference<Container>::type>::type; + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + } + + template <int N, typename Context, typename Container> + typename std::enable_if<(N>0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res) + { + using parent_context_t = typename Context::template partial<N-1>; + using current_context_t = typename Context::template partial<N>; + using CurrentMW = typename std::tuple_element<N, typename std::remove_reference<Container>::type>::type; + after_handler_call<CurrentMW, Context, parent_context_t>(std::get<N>(middlewares), req, res, ctx, static_cast<parent_context_t&>(ctx)); + after_handlers_call_helper<N-1, Context, Container>(middlewares, ctx, req, res); + } } using namespace boost; @@ -112,7 +169,8 @@ namespace crow cancel_deadline_timer(); bool is_invalid_request = false; - request req = parser_.to_request(); + req_ = std::move(parser_.to_request()); + request& req = req_; if (parser_.check_version(1, 0)) { // HTTP/1.0 @@ -138,17 +196,24 @@ namespace crow need_to_call_after_handlers_ = false; if (!is_invalid_request) { - res.complete_request_handler_ = [this]{ this->complete_request(); }; + res.complete_request_handler_ = []{}; res.is_alive_helper_ = [this]()->bool{ return socket_.is_open(); }; + ctx_ = detail::context<Middlewares...>(); req.middleware_context = (void*)&ctx_; detail::middleware_call_helper<0, decltype(ctx_), decltype(middlewares_), Middlewares...>(middlewares_, req, res, ctx_); + CROW_LOG_DEBUG << "ALATDA " << req.url; if (!res.completed_) { + res.complete_request_handler_ = [this]{ this->complete_request(); }; need_to_call_after_handlers_ = true; handler_->handle(req, res); } + else + { + complete_request(); + } } else { @@ -158,11 +223,16 @@ namespace crow void complete_request() { - CROW_LOG_INFO << "Response: " << this << ' ' << res.code << ' ' << close_connection_; + CROW_LOG_INFO << "Response: " << this << ' ' << req_.url << ' ' << res.code << ' ' << close_connection_; if (need_to_call_after_handlers_) { - // TODO call all of after_handlers + // call all after_handler of middlewares + detail::after_handlers_call_helper< + ((int)sizeof...(Middlewares)-1), + decltype(ctx_), + decltype(middlewares_)> + (middlewares_, ctx_, req_, res); } //auto self = this->shared_from_this(); @@ -381,6 +451,7 @@ namespace crow boost::array<char, 4096> buffer_; HTTPParser<Connection> parser_; + request req_; response res; bool close_connection_ = false; diff --git a/include/middleware_context.h b/include/middleware_context.h index 6dbf923..980a821 100644 --- a/include/middleware_context.h +++ b/include/middleware_context.h @@ -11,22 +11,47 @@ namespace crow : public black_magic::pop_back<Middlewares...>::template rebind<partial_context> , public black_magic::last_element_type<Middlewares...>::type::context { + using parent_context = typename black_magic::pop_back<Middlewares...>::template rebind<::crow::detail::partial_context>; + template <int N> + using partial = typename std::conditional<N == sizeof...(Middlewares)-1, partial_context, typename parent_context::template partial<N>>::type; + + template <typename T> + typename T::context& get() + { + return static_cast<typename T::context&>(*this); + } }; template <> struct partial_context<> { + template <int> + using partial = partial_context; }; + template <int N, typename Context, typename Container, typename CurrentMW, typename ... Middlewares> + bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx); + template <typename ... Middlewares> struct context : private partial_context<Middlewares...> //struct context : private Middlewares::context... // simple but less type-safe { + template <int N, typename Context, typename Container> + friend typename std::enable_if<(N==0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res); + template <int N, typename Context, typename Container> + friend typename std::enable_if<(N>0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res); + + template <int N, typename Context, typename Container, typename CurrentMW, typename ... Middlewares2> + friend bool middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx); + template <typename T> typename T::context& get() { return static_cast<typename T::context&>(*this); } + + template <int N> + using partial = typename partial_context<Middlewares...>::template partial<N>; }; } } diff --git a/include/utility.h b/include/utility.h index 35ea848..b2e61f1 100644 --- a/include/utility.h +++ b/include/utility.h @@ -256,12 +256,14 @@ template <typename F, typename Set> struct pop_back_helper<seq<N...>, Tuple> { template <template <typename ... Args> class U> - using rebind = U<std::tuple_element<N, Tuple>...>; + using rebind = U<typename std::tuple_element<N, Tuple>::type...>; }; template <typename ... T> - struct pop_back : public pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>> + struct pop_back //: public pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>> { + template <template <typename ... Args> class U> + using rebind = typename pop_back_helper<typename gen_seq<sizeof...(T)-1>::type, std::tuple<T...>>::template rebind<U>; }; template <> @@ -284,5 +286,10 @@ template <typename F, typename Set> template < typename Tp > struct contains<Tp> : std::false_type {}; + + template <typename T> + struct empty_context + { + }; } } diff --git a/tests/unittest.cpp b/tests/unittest.cpp index c84e3ed..a69e640 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -479,6 +479,14 @@ int testmain() return failed ? -1 : 0; } +TEST(black_magic) +{ + using namespace black_magic; + static_assert(std::is_same<void, last_element_type<int, char, void>::type>::value, "last_element_type"); + static_assert(std::is_same<char, pop_back<int, char, void>::rebind<last_element_type>::type>::value, "pop_back"); + static_assert(std::is_same<int, pop_back<int, char, void>::rebind<pop_back>::rebind<last_element_type>::type>::value, "pop_back"); +} + struct NullMiddleware { struct context {}; @@ -492,12 +500,25 @@ struct NullMiddleware {} }; +struct NullSimpleMiddleware +{ + struct context {}; + + void before_handle(request& req, response& res, context& ctx) + {} + + void after_handle(request& req, response& res, context& ctx) + {} +}; + TEST(middleware_simple) { - App<NullMiddleware> app; + App<NullMiddleware, NullSimpleMiddleware> app; + decltype(app)::server_t server(&app, 45451); CROW_ROUTE(app, "/")([&](const crow::request& req) { - app.get_middleware_context<NullMiddleware>(req); + app.get_context<NullMiddleware>(req); + app.get_context<NullSimpleMiddleware>(req); return ""; }); } @@ -519,21 +540,98 @@ struct IntSettingMiddleware } }; +std::vector<std::string> test_middleware_context_vector; + +struct FirstMW +{ + struct context + { + std::vector<string> v; + }; + + void before_handle(request& req, response& res, context& ctx) + { + ctx.v.push_back("1 before"); + } + + void after_handle(request& req, response& res, context& ctx) + { + ctx.v.push_back("1 after"); + test_middleware_context_vector = ctx.v; + } +}; + +struct SecondMW +{ + struct context {}; + template <typename AllContext> + void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + { + all_ctx.template get<FirstMW>().v.push_back("2 before"); + if (req.url == "/break") + res.end(); + } + + template <typename AllContext> + void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + { + all_ctx.template get<FirstMW>().v.push_back("2 after"); + } +}; + +struct ThirdMW +{ + struct context {}; + template <typename AllContext> + void before_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + { + all_ctx.template get<FirstMW>().v.push_back("3 before"); + } + + template <typename AllContext> + void after_handle(request& req, response& res, context& ctx, AllContext& all_ctx) + { + all_ctx.template get<FirstMW>().v.push_back("3 after"); + } +}; + TEST(middleware_context) { + static char buf[2048]; - App<IntSettingMiddleware> app; - Server<decltype(app), IntSettingMiddleware> server(&app, 45451); - auto _ = async(launch::async, [&]{server.run();}); - std::string sendmsg = "GET /\r\n\r\n"; + // SecondMW depends on FirstMW (it uses all_ctx.get<FirstMW>) + // so it leads to compile error if we remove FirstMW from definition + // App<IntSettingMiddleware, SecondMW> app; + // or change the order of FirstMW and SecondMW + // App<IntSettingMiddleware, SecondMW, FirstMW> app; + + App<IntSettingMiddleware, FirstMW, SecondMW, ThirdMW> app; int x{}; CROW_ROUTE(app, "/")([&](const request& req){ - auto& ctx = app.get_middleware_context<IntSettingMiddleware>(req); - x = ctx.val; + { + auto& ctx = app.get_context<IntSettingMiddleware>(req); + x = ctx.val; + } + { + auto& ctx = app.get_context<FirstMW>(req); + ctx.v.push_back("handle"); + } return ""; }); + CROW_ROUTE(app, "/break")([&](const request& req){ + { + auto& ctx = app.get_context<FirstMW>(req); + ctx.v.push_back("handle"); + } + + return ""; + }); + + decltype(app)::server_t server(&app, 45451); + auto _ = async(launch::async, [&]{server.run();}); + std::string sendmsg = "GET /\r\n\r\n"; asio::io_service is; { asio::ip::tcp::socket c(is); @@ -543,8 +641,38 @@ TEST(middleware_context) c.send(asio::buffer(sendmsg)); c.receive(asio::buffer(buf, 2048)); + c.close(); + } + { + auto& out = test_middleware_context_vector; + ASSERT_EQUAL(1, x); + ASSERT_EQUAL(7, out.size()); + ASSERT_EQUAL("1 before", out[0]); + ASSERT_EQUAL("2 before", out[1]); + ASSERT_EQUAL("3 before", out[2]); + ASSERT_EQUAL("handle", out[3]); + ASSERT_EQUAL("3 after", out[4]); + ASSERT_EQUAL("2 after", out[5]); + ASSERT_EQUAL("1 after", out[6]); + } + std::string sendmsg2 = "GET /break\r\n\r\n"; + { + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), 45451)); + + c.send(asio::buffer(sendmsg2)); + + c.receive(asio::buffer(buf, 2048)); + c.close(); + } + { + auto& out = test_middleware_context_vector; + ASSERT_EQUAL(4, out.size()); + ASSERT_EQUAL("1 before", out[0]); + ASSERT_EQUAL("2 before", out[1]); + ASSERT_EQUAL("2 after", out[2]); + ASSERT_EQUAL("1 after", out[3]); } - ASSERT_EQUAL(1, x); server.stop(); } |