diff --git a/cpp/csp/adapters/websocket/CMakeLists.txt b/cpp/csp/adapters/websocket/CMakeLists.txt index 402d01b69..6bd7d7ae3 100644 --- a/cpp/csp/adapters/websocket/CMakeLists.txt +++ b/cpp/csp/adapters/websocket/CMakeLists.txt @@ -5,7 +5,9 @@ set(WS_CLIENT_HEADER_FILES ClientInputAdapter.h ClientOutputAdapter.h ClientHeaderUpdateAdapter.h + ClientConnectionRequestAdapter.h WebsocketEndpoint.h + WebsocketEndpointManager.h ${WEBSOCKET_HEADER} ) @@ -14,7 +16,9 @@ set(WS_CLIENT_SOURCE_FILES ClientInputAdapter.cpp ClientOutputAdapter.cpp ClientHeaderUpdateAdapter.cpp + ClientConnectionRequestAdapter.cpp WebsocketEndpoint.cpp + WebsocketEndpointManager.cpp ${WS_CLIENT_HEADER_FILES} ${WEBSOCKET_SOURCE} ) diff --git a/cpp/csp/adapters/websocket/ClientAdapterManager.cpp b/cpp/csp/adapters/websocket/ClientAdapterManager.cpp index 423f2a234..ab609462b 100644 --- a/cpp/csp/adapters/websocket/ClientAdapterManager.cpp +++ b/cpp/csp/adapters/websocket/ClientAdapterManager.cpp @@ -1,124 +1,59 @@ #include -namespace csp { - -INIT_CSP_ENUM( adapters::websocket::ClientStatusType, - "ACTIVE", - "GENERIC_ERROR", - "CONNECTION_FAILED", - "CLOSED", - "MESSAGE_SEND_FAIL", -); - -} - -// With TLS namespace csp::adapters::websocket { ClientAdapterManager::ClientAdapterManager( Engine* engine, const Dictionary & properties ) : AdapterManager( engine ), - m_active( false ), - m_shouldRun( false ), - m_endpoint( std::make_unique( properties ) ), - m_inputAdapter( nullptr ), - m_outputAdapter( nullptr ), - m_updateAdapter( nullptr ), - m_thread( nullptr ), - m_properties( properties ) -{ }; + m_properties( properties ) +{ } ClientAdapterManager::~ClientAdapterManager() -{ }; +{ } -void ClientAdapterManager::start( DateTime starttime, DateTime endtime ) -{ - AdapterManager::start( starttime, endtime ); - - m_shouldRun = true; - m_endpoint -> setOnOpen( - [ this ]() { - m_active = true; - pushStatus( StatusLevel::INFO, ClientStatusType::ACTIVE, "Connected successfully" ); - } - ); - m_endpoint -> setOnFail( - [ this ]( const std::string& reason ) { - std::stringstream ss; - ss << "Connection Failure: " << reason; - m_active = false; - pushStatus( StatusLevel::ERROR, ClientStatusType::CONNECTION_FAILED, ss.str() ); - } - ); - if( m_inputAdapter ) { - m_endpoint -> setOnMessage( - [ this ]( void* c, size_t t ) { - PushBatch batch( m_engine -> rootEngine() ); - m_inputAdapter -> processMessage( c, t, &batch ); - } - ); - } else { - // if a user doesn't call WebsocketAdapterManager.subscribe, no inputadapter will be created - // but we still need something to avoid on_message_cb not being set in the endpoint. - m_endpoint -> setOnMessage( []( void* c, size_t t ){} ); - } - m_endpoint -> setOnClose( - [ this ]() { - m_active = false; - pushStatus( StatusLevel::INFO, ClientStatusType::CLOSED, "Connection closed" ); - } - ); - m_endpoint -> setOnSendFail( - [ this ]( const std::string& s ) { - std::stringstream ss; - ss << "Failed to send: " << s; - pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, ss.str() ); - } - ); +WebsocketEndpointManager* ClientAdapterManager::getWebsocketManager(){ + if( m_endpointManager == nullptr ) + return nullptr; + return m_endpointManager.get(); +} - m_thread = std::make_unique( [ this ]() { - while( m_shouldRun ) - { - m_endpoint -> run(); - m_active = false; - if( m_shouldRun ) sleep( m_properties.get( "reconnect_interval" ) ); - } - }); -}; +void ClientAdapterManager::start(DateTime starttime, DateTime endtime) { + AdapterManager::start(starttime, endtime); + if (m_endpointManager != nullptr) + m_endpointManager -> start(starttime, endtime); +} void ClientAdapterManager::stop() { AdapterManager::stop(); - - m_shouldRun=false; - if( m_active ) m_endpoint->stop(); - if( m_thread ) m_thread->join(); -}; + if (m_endpointManager != nullptr) + m_endpointManager -> stop(); +} PushInputAdapter* ClientAdapterManager::getInputAdapter(CspTypePtr & type, PushMode pushMode, const Dictionary & properties) -{ - if (m_inputAdapter == nullptr) - { - m_inputAdapter = m_engine -> createOwnedObject( - // m_engine, - type, - pushMode, - properties - ); - } - return m_inputAdapter; -}; +{ + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique(this, m_properties, m_engine); + return m_endpointManager -> getInputAdapter( type, pushMode, properties ); +} -OutputAdapter* ClientAdapterManager::getOutputAdapter() +OutputAdapter* ClientAdapterManager::getOutputAdapter( const Dictionary & properties ) { - if (m_outputAdapter == nullptr) m_outputAdapter = m_engine -> createOwnedObject(*m_endpoint); - - return m_outputAdapter; + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique(this, m_properties, m_engine); + return m_endpointManager -> getOutputAdapter( properties ); } OutputAdapter * ClientAdapterManager::getHeaderUpdateAdapter() { - if (m_updateAdapter == nullptr) m_updateAdapter = m_engine -> createOwnedObject( m_endpoint -> getProperties() ); + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique(this, m_properties, m_engine); + return m_endpointManager -> getHeaderUpdateAdapter(); +} - return m_updateAdapter; +OutputAdapter * ClientAdapterManager::getConnectionRequestAdapter( const Dictionary & properties ) +{ + if (m_endpointManager == nullptr) + m_endpointManager = std::make_unique(this, m_properties, m_engine); + return m_endpointManager -> getConnectionRequestAdapter( properties ); } DateTime ClientAdapterManager::processNextSimTimeSlice( DateTime time ) diff --git a/cpp/csp/adapters/websocket/ClientAdapterManager.h b/cpp/csp/adapters/websocket/ClientAdapterManager.h index 62577d769..101e05f98 100644 --- a/cpp/csp/adapters/websocket/ClientAdapterManager.h +++ b/cpp/csp/adapters/websocket/ClientAdapterManager.h @@ -2,8 +2,8 @@ #define _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_ADAPTERMGR_H #include +#include #include -#include #include #include #include @@ -15,30 +15,15 @@ #include #include #include +#include +#include namespace csp::adapters::websocket { using namespace csp; -struct WebsocketClientStatusTypeTraits -{ - enum _enum : unsigned char - { - ACTIVE = 0, - GENERIC_ERROR = 1, - CONNECTION_FAILED = 2, - CLOSED = 3, - MESSAGE_SEND_FAIL = 4, - - NUM_TYPES - }; - -protected: - _enum m_value; -}; - -using ClientStatusType = Enum; +class WebsocketEndpointManager; class ClientAdapterManager final : public AdapterManager { @@ -57,23 +42,17 @@ class ClientAdapterManager final : public AdapterManager void stop() override; + WebsocketEndpointManager* getWebsocketManager(); PushInputAdapter * getInputAdapter( CspTypePtr & type, PushMode pushMode, const Dictionary & properties ); - OutputAdapter * getOutputAdapter(); + OutputAdapter * getOutputAdapter( const Dictionary & properties ); OutputAdapter * getHeaderUpdateAdapter(); + OutputAdapter * getConnectionRequestAdapter( const Dictionary & properties ); DateTime processNextSimTimeSlice( DateTime time ) override; private: - // need some client info - - bool m_active; - bool m_shouldRun; - std::unique_ptr m_endpoint; - ClientInputAdapter* m_inputAdapter; - ClientOutputAdapter* m_outputAdapter; - ClientHeaderUpdateOutputAdapter* m_updateAdapter; - std::unique_ptr m_thread; Dictionary m_properties; + std::unique_ptr m_endpointManager; }; } diff --git a/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp new file mode 100644 index 000000000..2bef1ffb5 --- /dev/null +++ b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.cpp @@ -0,0 +1,66 @@ +#include + +namespace csp::adapters::websocket { + +ClientConnectionRequestAdapter::ClientConnectionRequestAdapter( + Engine * engine, + WebsocketEndpointManager * websocketManager, + bool is_subscribe, + size_t caller_id, + boost::asio::strand& strand + +) : OutputAdapter( engine ), + m_websocketManager( websocketManager ), + m_strand( strand ), + m_isSubscribe( is_subscribe ), + m_callerId( caller_id ), + m_checkPerformed( is_subscribe ? false : true ) // we only need to check for pruned input adapters +{} + +void ClientConnectionRequestAdapter::executeImpl() +{ + // One-time check for pruned status + if (unlikely(!m_checkPerformed)) { + m_isPruned = m_websocketManager->adapterPruned(m_callerId); + m_checkPerformed = true; + } + + // Early return if pruned + if (unlikely(m_isPruned)) + return; + + std::vector properties_list; + for (auto& request : input()->lastValueTyped>()) { + if (!request->allFieldsSet()) + CSP_THROW(TypeError, "All fields must be set in InternalConnectionRequest"); + + Dictionary dict; + dict.update("host", request->host()); + dict.update("port", request->port()); + dict.update("route", request->route()); + dict.update("uri", request->uri()); + dict.update("use_ssl", request->use_ssl()); + dict.update("reconnect_interval", request->reconnect_interval()); + dict.update("persistent", request->persistent()); + + dict.update("headers", request -> headers() ); + dict.update("on_connect_payload", request->on_connect_payload()); + dict.update("action", request->action()); + dict.update("dynamic", request->dynamic()); + dict.update("binary", request->binary()); + + properties_list.push_back(std::move(dict)); + } + + // We intentionally post here, we want the thread running + // the strand to handle the connection request. We want to keep + // all updates to internal data structures at graph run-time + // to that thread. + boost::asio::post(m_strand, [this, properties_list=std::move(properties_list)]() { + for(const auto& conn_req: properties_list) { + m_websocketManager->handleConnectionRequest(conn_req, m_callerId, m_isSubscribe); + } + }); +}; + +} \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h new file mode 100644 index 000000000..505ea1164 --- /dev/null +++ b/cpp/csp/adapters/websocket/ClientConnectionRequestAdapter.h @@ -0,0 +1,45 @@ +#ifndef _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_CONNECTIONREQUESTADAPTER_H +#define _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_CONNECTIONREQUESTADAPTER_H + +#include +#include +#include +#include +#include + +namespace csp::adapters::websocket +{ +using namespace csp::autogen; + +class ClientAdapterManager; +class WebsocketEndpointManager; + +class ClientConnectionRequestAdapter final: public OutputAdapter +{ +public: + ClientConnectionRequestAdapter( + Engine * engine, + WebsocketEndpointManager * websocketManager, + bool isSubscribe, + size_t callerId, + boost::asio::strand& strand + ); + + void executeImpl() override; + + const char * name() const override { return "WebsocketClientConnectionRequestAdapter"; } + +private: + WebsocketEndpointManager* m_websocketManager; + boost::asio::strand& m_strand; + bool m_isSubscribe; + size_t m_callerId; + bool m_checkPerformed; + bool m_isPruned{false}; + +}; + +} + + +#endif \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp index 995bd7314..c25a368b7 100644 --- a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.cpp @@ -2,19 +2,27 @@ namespace csp::adapters::websocket { +class WebsocketEndpointManager; + ClientHeaderUpdateOutputAdapter::ClientHeaderUpdateOutputAdapter( Engine * engine, - Dictionary& properties -) : OutputAdapter( engine ), m_properties( properties ) + WebsocketEndpointManager * mgr, + boost::asio::strand& strand +) : OutputAdapter( engine ), m_mgr( mgr ), m_strand( strand ) { }; void ClientHeaderUpdateOutputAdapter::executeImpl() { - DictionaryPtr headers = m_properties.get("headers"); - for( auto& update : input() -> lastValueTyped>() ) - { - if( update -> key_isSet() && update -> value_isSet() ) headers->update( update->key(), update->value() ); + Dictionary headers; + for (auto& update : input()->lastValueTyped>()) { + if (update->key_isSet() && update->value_isSet()) { + headers.update(update->key(), update->value()); + } } + boost::asio::post(m_strand, [this, headers=std::move(headers)]() { + auto endpoint = m_mgr -> getNonDynamicEndpoint(); + endpoint -> updateHeaders(std::move(headers)); + }); }; } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h index d2c898a1e..88d0ec439 100644 --- a/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h +++ b/cpp/csp/adapters/websocket/ClientHeaderUpdateAdapter.h @@ -1,6 +1,7 @@ #ifndef _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_HEADERUPDATEADAPTER_H #define _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_HEADERUPDATEADAPTER_H +#include #include #include #include @@ -10,12 +11,15 @@ namespace csp::adapters::websocket { using namespace csp::autogen; +class WebsocketEndpointManager; + class ClientHeaderUpdateOutputAdapter final: public OutputAdapter { public: ClientHeaderUpdateOutputAdapter( Engine * engine, - Dictionary& properties + WebsocketEndpointManager * mgr, + boost::asio::strand& strand ); void executeImpl() override; @@ -23,7 +27,10 @@ class ClientHeaderUpdateOutputAdapter final: public OutputAdapter const char * name() const override { return "WebsocketClientHeaderUpdateAdapter"; } private: - Dictionary& m_properties; + WebsocketEndpointManager * m_mgr; + boost::asio::strand& m_strand; + + }; diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp index e4b0b7ff7..103b63aff 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp @@ -1,5 +1,4 @@ #include - namespace csp::adapters::websocket { @@ -7,8 +6,10 @@ ClientInputAdapter::ClientInputAdapter( Engine * engine, CspTypePtr & type, PushMode pushMode, - const Dictionary & properties -) : PushInputAdapter(engine, type, pushMode) + const Dictionary & properties, + bool dynamic +) : PushInputAdapter(engine, type, pushMode), + m_dynamic( dynamic ) { if( type -> type() != CspType::Type::STRUCT && type -> type() != CspType::Type::STRING ) @@ -21,20 +22,48 @@ ClientInputAdapter::ClientInputAdapter( if( !metaFieldMap.empty() && type -> type() != CspType::Type::STRUCT ) CSP_THROW( ValueError, "meta_field_map is not supported on non-struct types" ); } + if ( m_dynamic ){ + auto& actual_type = static_cast( *type ); + auto& nested_type = actual_type.meta()-> field( "msg" ) -> type(); - m_converter = adapters::utils::MessageStructConverterCache::instance().create( type, properties ); + m_converter = adapters::utils::MessageStructConverterCache::instance().create( nested_type, properties ); + } + else + m_converter = adapters::utils::MessageStructConverterCache::instance().create( type, properties ); }; -void ClientInputAdapter::processMessage( void* c, size_t t, PushBatch* batch ) +void ClientInputAdapter::processMessage( const std::string& source, void * c, size_t t, PushBatch* batch ) { + if ( m_dynamic ){ + auto& actual_type = static_cast( *dataType() ); + auto& nested_type = actual_type.meta()-> field( "msg" ) -> type(); + auto true_val = actual_type.meta() -> create(); + actual_type.meta()->field("uri")->setValue( true_val.get(), source ); - if( dataType() -> type() == CspType::Type::STRUCT ) - { - auto tick = m_converter -> asStruct( c, t ); - pushTick( std::move(tick), batch ); - } else if ( dataType() -> type() == CspType::Type::STRING ) - { - pushTick( std::string((char const*)c, t), batch ); + if( nested_type -> type() == CspType::Type::STRUCT ) + { + auto tick = m_converter -> asStruct( c, t ); + actual_type.meta()->field("msg")->setValue( true_val.get(), std::move(tick) ); + + pushTick( std::move(true_val), batch ); + } else if ( nested_type -> type() == CspType::Type::STRING ) + { + auto msg = std::string((char const*)c, t); + actual_type.meta()->field("msg")->setValue( true_val.get(), msg ); + + pushTick( std::move(true_val), batch ); + } + + } + else{ + if( dataType() -> type() == CspType::Type::STRUCT ) + { + auto tick = m_converter -> asStruct( c, t ); + pushTick( std::move(tick), batch ); + } else if ( dataType() -> type() == CspType::Type::STRING ) + { + pushTick( std::string((char const*)c, t), batch ); + } } } diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.h b/cpp/csp/adapters/websocket/ClientInputAdapter.h index bf3cb295f..0f5a4223b 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.h +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.h @@ -16,17 +16,19 @@ class ClientInputAdapter final: public PushInputAdapter { Engine * engine, CspTypePtr & type, PushMode pushMode, - const Dictionary & properties + const Dictionary & properties, + bool dynamic ); - void processMessage( void* c, size_t t, PushBatch* batch ); + void processMessage( const std::string& source, void * c, size_t t, PushBatch* batch ); private: adapters::utils::MessageStructConverterPtr m_converter; + const bool m_dynamic; }; } -#endif // _IN_CSP_ADAPTERS_WEBSOCKETS_CLIENT_INPUTADAPTER_H \ No newline at end of file +#endif \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp b/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp index 3ef3c91ac..7b9bd83e1 100644 --- a/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientOutputAdapter.cpp @@ -4,14 +4,23 @@ namespace csp::adapters::websocket { ClientOutputAdapter::ClientOutputAdapter( Engine * engine, - WebsocketEndpoint& endpoint -) : OutputAdapter( engine ), m_endpoint( endpoint ) + WebsocketEndpointManager * websocketManager, + size_t caller_id, + net::io_context& ioc, + boost::asio::strand& strand +) : OutputAdapter( engine ), + m_websocketManager( websocketManager ), + m_callerId( caller_id ), + m_ioc( ioc ), + m_strand( strand ) { }; void ClientOutputAdapter::executeImpl() { const std::string & value = input() -> lastValueTyped(); - m_endpoint.send( value ); -}; + boost::asio::post(m_strand, [this, value=value]() { + m_websocketManager->send(value, m_callerId); + }); +} } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/ClientOutputAdapter.h b/cpp/csp/adapters/websocket/ClientOutputAdapter.h index 905822e2f..d97bc8062 100644 --- a/cpp/csp/adapters/websocket/ClientOutputAdapter.h +++ b/cpp/csp/adapters/websocket/ClientOutputAdapter.h @@ -5,11 +5,13 @@ #include #include #include +#include namespace csp::adapters::websocket { class ClientAdapterManager; +class WebsocketEndpointManager; class ClientOutputAdapter final: public OutputAdapter { @@ -17,7 +19,11 @@ class ClientOutputAdapter final: public OutputAdapter public: ClientOutputAdapter( Engine * engine, - WebsocketEndpoint& endpoint + WebsocketEndpointManager * websocketManager, + size_t caller_id, + net::io_context& ioc, + boost::asio::strand& strand + // bool dynamic ); void executeImpl() override; @@ -25,7 +31,12 @@ class ClientOutputAdapter final: public OutputAdapter const char * name() const override { return "WebsocketClientOutputAdapter"; } private: - WebsocketEndpoint& m_endpoint; + WebsocketEndpointManager* m_websocketManager; + size_t m_callerId; + [[maybe_unused]] net::io_context& m_ioc; + boost::asio::strand& m_strand; + // bool m_dynamic; + // std::unordered_map>& m_endpoint_consumers; }; } diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp b/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp index f503b8e98..93f6e8918 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.cpp @@ -1,12 +1,22 @@ #include +#include namespace csp::adapters::websocket { using namespace csp; -WebsocketEndpoint::WebsocketEndpoint( +WebsocketEndpoint::WebsocketEndpoint( + net::io_context& ioc, Dictionary properties -) : m_properties(properties) -{ }; +) : m_properties( std::make_shared( std::move( properties ) ) ), + m_ioc(ioc) +{ + std::string headerProps = m_properties->get("headers"); + // Create new empty headers dictionary + auto headers = std::make_shared(); + m_properties->update("headers", headers); + // Update with any existing header properties + updateHeaders(headerProps); +} void WebsocketEndpoint::setOnOpen(void_cb on_open) { m_on_open = std::move(on_open); } void WebsocketEndpoint::setOnFail(string_cb on_fail) @@ -20,17 +30,16 @@ void WebsocketEndpoint::setOnSendFail(string_cb on_send_fail) void WebsocketEndpoint::run() { - - m_ioc.restart(); - if(m_properties.get("use_ssl")) { + // Owns this ioc object + if(m_properties->get("use_ssl")) { ssl::context ctx{ssl::context::sslv23}; ctx.set_verify_mode(ssl::context::verify_peer ); ctx.set_default_verify_paths(); - m_session = new WebsocketSessionTLS( + m_session = std::make_shared( m_ioc, ctx, - &m_properties, + m_properties, m_on_open, m_on_fail, m_on_message, @@ -38,9 +47,9 @@ void WebsocketEndpoint::run() m_on_send_fail ); } else { - m_session = new WebsocketSessionNoTLS( + m_session = std::make_shared( m_ioc, - &m_properties, + m_properties, m_on_open, m_on_fail, m_on_message, @@ -49,23 +58,57 @@ void WebsocketEndpoint::run() ); } m_session->run(); +} + +WebsocketEndpoint::~WebsocketEndpoint() { + try { + // Call stop but explicitly pass false to prevent io_context shutdown + stop(false); + } catch (...) { + // Ignore any exceptions during cleanup + } +} - m_ioc.run(); +void WebsocketEndpoint::stop( bool stop_ioc ) +{ + if( m_session ) m_session->stop(); + if( stop_ioc ) m_ioc.stop(); } -void WebsocketEndpoint::stop() -{ - m_ioc.stop(); - if(m_session) m_session->stop(); +void WebsocketEndpoint::updateHeaders(csp::Dictionary properties){ + DictionaryPtr headers = m_properties->get("headers"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + std::string key = it.key(); + auto value = it.value(); + headers->update(key, std::move(value)); + } } +void WebsocketEndpoint::updateHeaders(const std::string& properties) { + if( properties.empty() ) + return; + DictionaryPtr headers = m_properties->get("headers"); + rapidjson::Document doc; + doc.Parse(properties.c_str()); + if (doc.IsObject()) { + // Windows builds complained with range loop + for (auto it = doc.MemberBegin(); it != doc.MemberEnd(); ++it) { + if (it->value.IsString()) { + std::string key = it->name.GetString(); + std::string value = it->value.GetString(); + headers->update(key, std::move(value)); + } + } + } +} -csp::Dictionary& WebsocketEndpoint::getProperties() { +std::shared_ptr WebsocketEndpoint::getProperties() { return m_properties; } void WebsocketEndpoint::send(const std::string& s) { if(m_session) m_session->send(s); } - +void WebsocketEndpoint::ping() +{ if(m_session) m_session->ping(); } } \ No newline at end of file diff --git a/cpp/csp/adapters/websocket/WebsocketEndpoint.h b/cpp/csp/adapters/websocket/WebsocketEndpoint.h index cfca08742..aafc237d1 100644 --- a/cpp/csp/adapters/websocket/WebsocketEndpoint.h +++ b/cpp/csp/adapters/websocket/WebsocketEndpoint.h @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include #include @@ -13,6 +16,7 @@ #include #include #include +#include namespace csp::adapters::websocket { using namespace csp; @@ -23,6 +27,7 @@ namespace net = boost::asio; // from namespace ssl = boost::asio::ssl; // from namespace websocket = beast::websocket; // from using tcp = boost::asio::ip::tcp; // from +using error_code = boost::system::error_code; //from using string_cb = std::function; using char_cb = std::function; @@ -30,7 +35,9 @@ using void_cb = std::function; class BaseWebsocketSession { public: + virtual ~BaseWebsocketSession() = default; virtual void stop() { }; + virtual void ping() { }; virtual void send( const std::string& ) { }; virtual void do_read() { }; virtual void do_write(const std::string& ) { }; @@ -38,24 +45,28 @@ class BaseWebsocketSession { }; template -class WebsocketSession : public BaseWebsocketSession { +class WebsocketSession : + public BaseWebsocketSession, + public std::enable_shared_from_this +{ public: WebsocketSession( net::io_context& ioc, - Dictionary* properties, - void_cb& on_open, - string_cb& on_fail, - char_cb& on_message, - void_cb& on_close, - string_cb& on_send_fail - ) : m_resolver( net::make_strand( ioc ) ), - m_properties( properties ), - m_on_open( on_open ), - m_on_fail( on_fail ), - m_on_message( on_message ), - m_on_close( on_close ), - m_on_send_fail( on_send_fail ) - { }; + std::shared_ptr properties, + void_cb on_open, + string_cb on_fail, + char_cb on_message, + void_cb on_close, + string_cb on_send_fail + ) : m_resolver(net::make_strand(ioc)), + m_properties(properties), + m_on_open(std::move(on_open)), + m_on_fail(std::move(on_fail)), + m_on_message(std::move(on_message)), + m_on_close(std::move(on_close)), + m_on_send_fail(std::move(on_send_fail)) + { } + ~WebsocketSession() override = default; Derived& derived(){ return static_cast(*this); } @@ -81,53 +92,66 @@ class WebsocketSession : public BaseWebsocketSession { } void do_read() override { + auto self = std::static_pointer_cast(this->shared_from_this()); derived().ws().async_read( - m_buffer, - [ this ]( beast::error_code ec, std::size_t bytes_transfered ) - { handle_message( ec, bytes_transfered ); } + self->m_buffer, + [ self ](beast::error_code ec, std::size_t bytes_transfered) { + self->handle_message(ec, bytes_transfered); + } ); } + void ping() override { + auto self = std::static_pointer_cast(this->shared_from_this()); + derived().ws().async_ping({}, + [ self ](beast::error_code ec) { + if(ec) self->m_on_send_fail("Failed to ping"); + }); + } + void stop() override { - derived().ws().async_close( websocket::close_code::normal, [ this ]( beast::error_code ec ) { - if(ec) CSP_THROW(RuntimeException, ec.message()); - m_on_close(); + auto self = std::static_pointer_cast(this->shared_from_this()); + derived().ws().async_close( websocket::close_code::normal, [ self ]( beast::error_code ec ) { + if(ec) self->m_on_fail(ec.message()); + self -> m_on_close(); }); } - void send( const std::string& s ) override + void send(const std::string& s) override { + auto self = std::static_pointer_cast(this->shared_from_this()); net::post( derived().ws().get_executor(), - [this, s]() + [ self, s]() { - m_queue.push_back(s); - if (m_queue.size() > 1) return; - do_write(m_queue.front()); + self->m_queue.push_back(s); + if (self->m_queue.size() > 1) return; + self->do_write(self->m_queue.front()); } ); } void do_write(const std::string& s) override { + auto self = std::static_pointer_cast(this->shared_from_this()); derived().ws().async_write( net::buffer(s), - [this](beast::error_code ec, std::size_t bytes_transfered) + [self](beast::error_code ec, std::size_t bytes_transfered) { - // add logging here? - m_queue.erase(m_queue.begin()); + self->m_queue.erase(self->m_queue.begin()); boost::ignore_unused(bytes_transfered); - if(ec) m_on_send_fail(ec.message()); - if(m_queue.size() >0) do_write(m_queue.front()); + if(ec) self->m_on_send_fail(ec.message()); + if(self->m_queue.size() > 0) + self->do_write(self->m_queue.front()); } ); - } +} public: tcp::resolver m_resolver; - Dictionary* m_properties; + std::shared_ptr m_properties; void_cb m_on_open; string_cb m_on_fail; char_cb m_on_message; @@ -142,7 +166,7 @@ class WebsocketSessionNoTLS final: public WebsocketSession properties, void_cb& on_open, string_cb& on_fail, char_cb& on_message, @@ -161,57 +185,60 @@ class WebsocketSessionNoTLS final: public WebsocketSession(this->shared_from_this()); m_resolver.async_resolve( m_properties->get("host").c_str(), m_properties->get("port").c_str(), - [this]( beast::error_code ec, tcp::resolver::results_type results ) { + [ self ]( beast::error_code ec, tcp::resolver::results_type results ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } // Set the timeout for the operation - beast::get_lowest_layer(m_ws).expires_after(std::chrono::seconds(5)); + beast::get_lowest_layer(self->m_ws).expires_after(std::chrono::seconds(5)); // Make the connection on the IP address we get from a lookup - beast::get_lowest_layer(m_ws).async_connect( + beast::get_lowest_layer(self->m_ws).async_connect( results, - [this]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) + [self]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) { // Turn off the timeout on the tcp_stream, because // the websocket stream has its own timeout system. if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - beast::get_lowest_layer(m_ws).expires_never(); + beast::get_lowest_layer(self->m_ws).expires_never(); - m_ws.set_option( + self->m_ws.set_option( websocket::stream_base::timeout::suggested( beast::role_type::client)); - m_ws.set_option(websocket::stream_base::decorator( - [this](websocket::request_type& req) + self->m_ws.set_option(websocket::stream_base::decorator( + [self](websocket::request_type& req) { - set_headers(req); + self -> set_headers(req); req.set(http::field::user_agent, "CSP WebsocketEndpoint"); } )); - std::string host_ = m_properties->get("host") + ':' + std::to_string(ep.port()); - m_ws.async_handshake( + std::string host_ = self->m_properties->get("host") + ':' + std::to_string(ep.port()); + self->m_ws.async_handshake( host_, - m_properties->get("route"), - [this]( beast::error_code ec ) { + self->m_properties->get("route"), + [self]( beast::error_code ec ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - m_on_open(); - m_ws.async_read( - m_buffer, - [ this ]( beast::error_code ec, std::size_t bytes_transfered ) - { handle_message( ec, bytes_transfered ); } + if( self->m_properties->get("binary") ) + self->m_ws.binary( true ); + self->m_on_open(); + self->m_ws.async_read( + self->m_buffer, + [ self ]( beast::error_code ec, std::size_t bytes_transfered ) + { self->handle_message( ec, bytes_transfered ); } ); } ); @@ -232,7 +259,7 @@ class WebsocketSessionTLS final: public WebsocketSession { WebsocketSessionTLS( net::io_context& ioc, ssl::context& ctx, - Dictionary* properties, + std::shared_ptr properties, void_cb& on_open, string_cb& on_fail, char_cb& on_message, @@ -251,73 +278,76 @@ class WebsocketSessionTLS final: public WebsocketSession { { } void run() override { + auto self = std::static_pointer_cast(this->shared_from_this()); m_resolver.async_resolve( m_properties->get("host").c_str(), m_properties->get("port").c_str(), - [this]( beast::error_code ec, tcp::resolver::results_type results ) { + [self]( beast::error_code ec, tcp::resolver::results_type results ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } // Set the timeout for the operation - beast::get_lowest_layer(m_ws).expires_after(std::chrono::seconds(5)); + beast::get_lowest_layer(self->m_ws).expires_after(std::chrono::seconds(5)); // Make the connection on the IP address we get from a lookup - beast::get_lowest_layer(m_ws).async_connect( + beast::get_lowest_layer(self->m_ws).async_connect( results, - [this]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) + [self]( beast::error_code ec, tcp::resolver::results_type::endpoint_type ep ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } if(! SSL_set_tlsext_host_name( - m_ws.next_layer().native_handle(), - m_properties->get("host").c_str())) + self->m_ws.next_layer().native_handle(), + self->m_properties->get("host").c_str())) { ec = beast::error_code(static_cast(::ERR_get_error()), net::error::get_ssl_category()); - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - m_complete_host = m_properties->get("host") + ':' + std::to_string(ep.port()); + self->m_complete_host = self->m_properties->get("host") + ':' + std::to_string(ep.port()); // ssl handler - m_ws.next_layer().async_handshake( + self->m_ws.next_layer().async_handshake( ssl::stream_base::client, - [this]( beast::error_code ec ) { + [self]( beast::error_code ec ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - beast::get_lowest_layer(m_ws).expires_never(); + beast::get_lowest_layer(self->m_ws).expires_never(); // Set suggested timeout settings for the websocket - m_ws.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + self->m_ws.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); // Set a decorator to change the User-Agent of the handshake - m_ws.set_option(websocket::stream_base::decorator( - [this](websocket::request_type& req) + self->m_ws.set_option(websocket::stream_base::decorator( + [self](websocket::request_type& req) { - set_headers(req); + self->set_headers(req); req.set(http::field::user_agent, "CSP WebsocketAdapter"); })); - m_ws.async_handshake( - m_complete_host, - m_properties->get("route"), - [this]( beast::error_code ec ) { + self->m_ws.async_handshake( + self->m_complete_host, + self->m_properties->get("route"), + [self]( beast::error_code ec ) { if(ec) { - m_on_fail(ec.message()); + self->m_on_fail(ec.message()); return; } - m_on_open(); - m_ws.async_read( - m_buffer, - [ this ]( beast::error_code ec, std::size_t bytes_transfered ) - { handle_message( ec, bytes_transfered ); } + if( self->m_properties->get("binary") ) + self->m_ws.binary( true ); + self->m_on_open(); + self->m_ws.async_read( + self->m_buffer, + [ self ]( beast::error_code ec, std::size_t bytes_transfered ) + { self->handle_message( ec, bytes_transfered ); } ); } ); @@ -340,23 +370,26 @@ class WebsocketSessionTLS final: public WebsocketSession { class WebsocketEndpoint { public: - WebsocketEndpoint( Dictionary properties ); - virtual ~WebsocketEndpoint() { }; + WebsocketEndpoint( net::io_context& ioc, Dictionary properties ); + ~WebsocketEndpoint(); void setOnOpen(void_cb on_open); void setOnFail(string_cb on_fail); void setOnMessage(char_cb on_message); void setOnClose(void_cb on_close); void setOnSendFail(string_cb on_send_fail); - Dictionary& getProperties(); + void updateHeaders(Dictionary properties); + void updateHeaders(const std::string& properties); + std::shared_ptr getProperties(); void run(); - void stop(); + void stop( bool stop_ioc = true); void send(const std::string& s); + void ping(); private: - Dictionary m_properties; - BaseWebsocketSession* m_session; - net::io_context m_ioc; + std::shared_ptr m_properties; + std::shared_ptr m_session; + net::io_context& m_ioc; void_cb m_on_open; string_cb m_on_fail; char_cb m_on_message; diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp new file mode 100644 index 000000000..d124f04f9 --- /dev/null +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.cpp @@ -0,0 +1,442 @@ +#include + +namespace csp { + +INIT_CSP_ENUM( adapters::websocket::ClientStatusType, + "ACTIVE", + "GENERIC_ERROR", + "CONNECTION_FAILED", + "CLOSED", + "MESSAGE_SEND_FAIL", +); + +} +namespace csp::adapters::websocket { + +WebsocketEndpointManager::WebsocketEndpointManager( ClientAdapterManager* mgr, const Dictionary & properties, Engine* engine ) +: m_num_threads( static_cast(properties.get("num_threads")) ), + m_ioc( m_num_threads ), + m_engine( engine ), + m_strand( boost::asio::make_strand(m_ioc) ), + m_mgr( mgr ), + m_updateAdapter( nullptr ), + m_properties( properties ), + m_work_guard(boost::asio::make_work_guard(m_ioc)), + m_dynamic( properties.get("dynamic") ){ + // Total number of subscribe and send function calls, set on the adapter manager + // when is it created. Note, that some of the input adapters might have been + // pruned from the graph and won't get created. + auto input_size = static_cast(properties.get("subscribe_calls")); + m_inputAdapters.resize(input_size, nullptr); + m_consumer_endpoints.resize(input_size); + // send_calls + auto output_size = static_cast(properties.get("send_calls")); + m_outputAdapters.resize(output_size, nullptr); + m_producer_endpoints.resize(output_size); + + // We choose to not automatically size m_connectionRequestAdapters + // since the index there is not meaningful, + // producers and subscribers are combined. + // We just hold onto their pointers. +}; + +WebsocketEndpointManager::~WebsocketEndpointManager() +{ +} + +void WebsocketEndpointManager::start(DateTime starttime, DateTime endtime) { + m_ioc.restart(); + if( !m_dynamic ){ + boost::asio::post(m_strand, [this]() { + // We subscribe for both the subscribe and send calls + // But we probably should check here. + if( m_outputAdapters.size() == 1) + handleConnectionRequest(Dictionary(m_properties), 0, false); + // If we have an input adapter call AND it's not pruned. + if( m_inputAdapters.size() == 1 && !adapterPruned(0)) + handleConnectionRequest(Dictionary(m_properties), 0, true); + }); + } + for (size_t i = 0; i < m_num_threads; ++i) { + m_threads.emplace_back(std::make_unique([this]() { + m_ioc.run(); + })); + } +}; + +bool WebsocketEndpointManager::adapterPruned( size_t caller_id ){ + return m_inputAdapters[caller_id] == nullptr; +}; + +void WebsocketEndpointManager::send(const std::string& value, const size_t& caller_id) { + const auto& endpoints = m_producer_endpoints[caller_id]; + // For each endpoint this producer is connected to + for (const auto& endpoint_id : endpoints) { + // Double check the endpoint exists and producer is still valid + if(publishesToEndpoint(caller_id, endpoint_id)) { + auto it = m_endpoints.find(endpoint_id); + if( it != m_endpoints.end()) + it->second.get()->send(value); + } + } +}; + +void WebsocketEndpointManager::removeEndpointForCallerId(const std::string& endpoint_id, bool is_consumer, size_t validated_id) +{ + if (is_consumer) { + WebsocketEndpointManager::removeConsumer(endpoint_id, validated_id); + } else { + WebsocketEndpointManager::removeProducer(endpoint_id, validated_id); + } + if (canRemoveEndpoint(endpoint_id)) + shutdownEndpoint(endpoint_id); +} + +void WebsocketEndpointManager::shutdownEndpoint(const std::string& endpoint_id) { + // This functions should only be called from the thread running m_ioc + // Cancel any pending reconnection attempts + if (auto config_it = m_endpoint_configs.find(endpoint_id); + config_it != m_endpoint_configs.end()) { + config_it->second.reconnect_timer->cancel(); + m_endpoint_configs.erase(config_it); + } + + // Stop and remove the endpoint + // No need to stop, destructo handles it + if (auto endpoint_it = m_endpoints.find(endpoint_id); endpoint_it != m_endpoints.end()) + m_endpoints.erase(endpoint_it); + std::stringstream ss; + ss << "No more connections for endpoint={" << endpoint_id << "} Shutting down..."; + std::string msg = ss.str(); + m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::CLOSED, msg); +} + +void WebsocketEndpointManager::setupEndpoint(const std::string& endpoint_id, + std::unique_ptr endpoint, + std::string payload, + bool persist, + bool is_consumer, + size_t validated_id) +{ + // Store the endpoint first + auto& stored_endpoint = m_endpoints[endpoint_id] = std::move(endpoint); + + stored_endpoint->setOnOpen([this, endpoint_id, endpoint = stored_endpoint.get(), payload=std::move(payload), persist, is_consumer, validated_id]() { + auto [iter, inserted] = m_endpoint_configs.try_emplace(endpoint_id, m_ioc); + auto& config = iter->second; + config.connected = true; + config.attempting_reconnect = false; + + // Send consumer payloads + const auto& consumers = m_endpoint_consumers[endpoint_id]; + for (size_t i = 0; i < config.consumer_payloads.size(); ++i) { + if (!config.consumer_payloads[i].empty() && + i < consumers.size() && consumers[i]) { + endpoint->send(config.consumer_payloads[i]); + } + } + + // Send producer payloads + const auto& producers = m_endpoint_producers[endpoint_id]; + for (size_t i = 0; i < config.producer_payloads.size(); ++i) { + if (!config.producer_payloads[i].empty() && + i < producers.size() && producers[i]) { + endpoint->send(config.producer_payloads[i]); + } + } + // should only happen if persist is False + if ( !payload.empty() ) + endpoint -> send(payload); + std::stringstream ss; + ss << "Connected successfully for endpoint={" << endpoint_id << "}"; + std::string msg = ss.str(); + m_mgr -> pushStatus(StatusLevel::INFO, ClientStatusType::ACTIVE, msg); + // We remove the caller id, if it was the only one, then we shut down the endpoint + if( !persist ) + removeEndpointForCallerId(endpoint_id, is_consumer, validated_id); + }); + + stored_endpoint->setOnFail([this, endpoint_id](const std::string& reason) { + handleEndpointFailure(endpoint_id, reason, ClientStatusType::CONNECTION_FAILED); + }); + + stored_endpoint->setOnClose([this, endpoint_id]() { + // If we didn't close it ourselves + if (auto config_it = m_endpoint_configs.find(endpoint_id); config_it != m_endpoint_configs.end()) + handleEndpointFailure(endpoint_id, "Connection closed", ClientStatusType::CLOSED); + }); + stored_endpoint->setOnMessage([this, endpoint_id](void* data, size_t len) { + // Here we need to route to all active consumers for this endpoint + const auto& consumers = m_endpoint_consumers[endpoint_id]; + + // For each active consumer, we need to send to their input adapter + PushBatch batch( m_engine -> rootEngine() ); // TODO is this right? + for (size_t consumer_id = 0; consumer_id < consumers.size(); ++consumer_id) { + if (consumers[consumer_id]) { + std::vector data_copy(static_cast(data), + static_cast(data) + len); + auto tup = std::tuple{endpoint_id, data_copy.data()}; + m_inputAdapters[consumer_id] -> processMessage( endpoint_id, data_copy.data(), len, &batch ); + } + } + }); + stored_endpoint -> setOnSendFail( + [ this, endpoint_id ]( const std::string& s ) { + std::stringstream ss; + ss << "Error: " << s << " for endpoint={" << endpoint_id << "}"; + std::string msg = ss.str(); + m_mgr -> pushStatus( StatusLevel::ERROR, ClientStatusType::MESSAGE_SEND_FAIL, msg ); + } + ); + stored_endpoint -> run(); +}; + + +void WebsocketEndpointManager::handleEndpointFailure(const std::string& endpoint_id, + const std::string& reason, ClientStatusType status_type) { + // If there are any active consumers/producers, try to reconnect + if (!canRemoveEndpoint(endpoint_id)) { + auto [iter, inserted] = m_endpoint_configs.try_emplace(endpoint_id, m_ioc); + auto& config = iter->second; + config.connected = false; + + if (!config.attempting_reconnect) { + config.attempting_reconnect = true; + + // Schedule reconnection attempt + config.reconnect_timer->expires_after(config.reconnect_interval); + config.reconnect_timer->async_wait([this, endpoint_id](const error_code& ec) { + // boost::asio::post(m_ioc, [this, endpoint_id]() { + // If we still want to subscribe to this endpoint + if (auto it = m_endpoints.find(endpoint_id); + it != m_endpoints.end()) { + auto config_it = m_endpoint_configs.find(endpoint_id); + if (config_it != m_endpoint_configs.end()) { + auto& config = config_it -> second; + // We are no longer attempting to reconnect + config.attempting_reconnect = false; + } + it->second->run(); // Attempt to reconnect + } + }); + } + } else { + // No active consumers/producers, clean up the endpoint + m_endpoints.erase(endpoint_id); + m_endpoint_configs.erase(endpoint_id); + } + + std::stringstream ss; + ss << "Connection Failure for endpoint={" << endpoint_id << "} Due to: " << reason; + std::string msg = ss.str(); + if ( status_type == ClientStatusType::CLOSED || status_type == ClientStatusType::ACTIVE ) + m_mgr -> pushStatus(StatusLevel::INFO, status_type, msg); + else{ + m_mgr -> pushStatus(StatusLevel::ERROR, status_type, msg); + } +}; + +void WebsocketEndpointManager::handleConnectionRequest(const Dictionary & properties, size_t validated_id, bool is_subscribe) +{ + // This should only get called from the thread running + // m_ioc. This allows us to avoid locks on internal data + // structures + auto endpoint_id = properties.get("uri"); + autogen::ActionType action = autogen::ActionType::create( properties.get("action") ); + switch(action.enum_value()) { + case autogen::ActionType::enum_::CONNECT: { + auto persistent = properties.get("persistent"); + auto reconnect_interval = properties.get("reconnect_interval"); + // Update endpoint config + auto& config = m_endpoint_configs.try_emplace(endpoint_id, m_ioc).first->second; + + config.reconnect_interval = std::chrono::milliseconds( + reconnect_interval.asMilliseconds() + ); + std::string payload = ""; + bool has_payload = properties.tryGet("on_connect_payload", payload); + + if (has_payload && !payload.empty() && persistent) { + auto& payloads = is_subscribe ? config.consumer_payloads : config.producer_payloads; + if (payloads.size() <= validated_id) { + payloads.resize(validated_id + 1); + } + payloads[validated_id] = std::move(payload); // Move to config + } + + if ( persistent ){ + if (is_subscribe) { + WebsocketEndpointManager::addConsumer(endpoint_id, validated_id); + } else { + WebsocketEndpointManager::addProducer(endpoint_id, validated_id); + } + } + + bool is_new_endpoint = !m_endpoints.contains(endpoint_id); + if (is_new_endpoint) { + auto endpoint = std::make_unique(m_ioc, properties); + // We can safely move payload regardless - if it was never written to, it's just an empty string + WebsocketEndpointManager::setupEndpoint(endpoint_id, std::move(endpoint), + (has_payload && !payload.empty() && persistent) ? "" : std::move(payload), + persistent, is_subscribe, validated_id ); + } + else{ + if( !persistent && !payload.empty() ) + m_endpoints[endpoint_id]->send(payload); + // Conscious decision to let non-persisten connection + // results to update the header + auto headers = properties.get("headers"); + m_endpoints[endpoint_id]->updateHeaders(std::move(headers)); + } + break; + } + + case csp::autogen::ActionType::enum_::DISCONNECT: { + // Clear persistence flag for this caller + removeEndpointForCallerId(endpoint_id, is_subscribe, validated_id); + break; + } + + case csp::autogen::ActionType::enum_::PING: { + // Only ping if the caller is actually connected to this endpoint + auto& consumers = m_endpoint_consumers[endpoint_id]; + auto& producers = m_endpoint_producers[endpoint_id]; + + if ( ( is_subscribe && validated_id < consumers.size() && consumers[validated_id] ) || + ( !is_subscribe && validated_id < producers.size() && producers[validated_id] ) ) { + if (auto it = m_endpoints.find(endpoint_id); it != m_endpoints.end()) { + it->second.get()->ping(); + } + } + break; + } + } +}; + +WebsocketEndpoint * WebsocketEndpointManager::getNonDynamicEndpoint(){ + // Should only be called if dynamic = False + if (!m_endpoints.empty()) { + return m_endpoints.begin()->second.get(); + } + return nullptr; +} + +void WebsocketEndpointManager::addConsumer(const std::string& endpoint_id, size_t caller_id) { + ensureVectorSize(m_endpoint_consumers[endpoint_id], caller_id); + m_endpoint_consumers[endpoint_id][caller_id] = true; + + m_consumer_endpoints[caller_id].insert(endpoint_id); +}; + +void WebsocketEndpointManager::addProducer(const std::string& endpoint_id, size_t caller_id) { + ensureVectorSize(m_endpoint_producers[endpoint_id], caller_id); + m_endpoint_producers[endpoint_id][caller_id] = true; + + m_producer_endpoints[caller_id].insert(endpoint_id); +}; + +bool WebsocketEndpointManager::canRemoveEndpoint(const std::string& endpoint_id) { + const auto& consumers = m_endpoint_consumers[endpoint_id]; + const auto& producers = m_endpoint_producers[endpoint_id]; + + // Check if any true values exist in either vector + return std::none_of(consumers.begin(), consumers.end(), [](bool b) { return b; }) && + std::none_of(producers.begin(), producers.end(), [](bool b) { return b; }); +}; + +void WebsocketEndpointManager::removeConsumer(const std::string& endpoint_id, size_t caller_id) { + auto& consumers = m_endpoint_consumers[endpoint_id]; + // Possibility it might not be subscribed, + // so we have this check. + if (caller_id < consumers.size()) { + consumers[caller_id] = false; + } + // We initialize these upfront, this will be valid. + m_consumer_endpoints[caller_id].erase(endpoint_id); +}; + +void WebsocketEndpointManager::removeProducer(const std::string& endpoint_id, size_t caller_id) { + auto& producers = m_endpoint_producers[endpoint_id]; + // Possibility it might not be publihsing to + // so we have this check. + if (caller_id < producers.size()) { + producers[caller_id] = false; + } + + // We initialize these upfront, this will be valid. + m_producer_endpoints[caller_id].erase(endpoint_id); +}; + + +void WebsocketEndpointManager::stop() { + // Stop all endpoints + // Endpoints running on m_ioc thread, + // So we call stop there + boost::asio::post(m_strand, [this]() { + for (auto& [endpoint_id, _] : m_endpoints) { + shutdownEndpoint(endpoint_id); + } + }); + // Stop the work guard to allow the io_context to complete + m_work_guard.reset(); + m_ioc.stop(); + + // Wait for all threads to finish + for (auto& thread : m_threads) { + if (thread && thread->joinable()) { + thread->join(); + } + } + + // Clear threads before other members are destroyed + m_threads.clear(); +}; + +PushInputAdapter* WebsocketEndpointManager::getInputAdapter(CspTypePtr & type, PushMode pushMode, const Dictionary & properties) +{ + auto caller_id = properties.get("caller_id"); + size_t validated_id = validateCallerId(caller_id); + auto input_adapter = m_engine -> createOwnedObject( + type, + pushMode, + properties, + m_dynamic + ); + m_inputAdapters[validated_id] = input_adapter; + return m_inputAdapters[validated_id]; +}; + +OutputAdapter* WebsocketEndpointManager::getOutputAdapter( const Dictionary & properties ) +{ + auto caller_id = properties.get("caller_id"); + size_t validated_id = validateCallerId(caller_id); + assert(!properties.get("is_subscribe")); + assert(m_outputAdapters.size() == validated_id); + + auto output_adapter = m_engine -> createOwnedObject( this, validated_id, m_ioc, m_strand ); + m_outputAdapters[validated_id] = output_adapter; + return m_outputAdapters[validated_id]; +}; + +OutputAdapter * WebsocketEndpointManager::getHeaderUpdateAdapter() +{ + if (m_updateAdapter == nullptr) + m_updateAdapter = m_engine -> createOwnedObject( this, m_strand ); + + return m_updateAdapter; +}; + +OutputAdapter * WebsocketEndpointManager::getConnectionRequestAdapter( const Dictionary & properties ) +{ + auto caller_id = properties.get("caller_id"); + auto is_subscribe = properties.get("is_subscribe"); + + auto* adapter = m_engine->createOwnedObject( + this, is_subscribe, caller_id, m_strand + ); + m_connectionRequestAdapters.push_back(adapter); + + return adapter; +}; + +} diff --git a/cpp/csp/adapters/websocket/WebsocketEndpointManager.h b/cpp/csp/adapters/websocket/WebsocketEndpointManager.h new file mode 100644 index 000000000..9007251ed --- /dev/null +++ b/cpp/csp/adapters/websocket/WebsocketEndpointManager.h @@ -0,0 +1,170 @@ +#ifndef WEBSOCKET_ENDPOINT_MANAGER_H +#define WEBSOCKET_ENDPOINT_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace csp::adapters::websocket { +using namespace csp; +class WebsocketEndpoint; + +class ClientAdapterManager; +class ClientOutputAdapter; +class ClientConnectionRequestAdapter; +class ClientHeaderUpdateOutputAdapter; + +struct ConnectPayloads { + std::vector consumer_payloads; + std::vector producer_payloads; +}; + +struct EndpointConfig { + std::chrono::milliseconds reconnect_interval; + std::unique_ptr reconnect_timer; + bool attempting_reconnect{false}; + bool connected{false}; + + // Payloads for different client types + std::vector consumer_payloads; + std::vector producer_payloads; + + explicit EndpointConfig(boost::asio::io_context& ioc) + : reconnect_timer(std::make_unique(ioc)) {} +}; + +// Callbacks for endpoint events +struct EndpointCallbacks { + std::function onOpen; + std::function onFail; + std::function onClose; + std::function onSendFail; + std::function onMessage; +}; + +struct WebsocketClientStatusTypeTraits +{ + enum _enum : unsigned char + { + ACTIVE = 0, + GENERIC_ERROR = 1, + CONNECTION_FAILED = 2, + CLOSED = 3, + MESSAGE_SEND_FAIL = 4, + + NUM_TYPES + }; + +protected: + _enum m_value; +}; + +using ClientStatusType = Enum; + +class WebsocketEndpointManager { +public: + explicit WebsocketEndpointManager(ClientAdapterManager* mgr, const Dictionary & properties, Engine* engine); + ~WebsocketEndpointManager(); + void send(const std::string& value, const size_t& caller_id); + // Whether the input adapter (subscribe) given by a specific caller_id was pruned + bool adapterPruned( size_t caller_id ); + // Whether the output adapater (publish) given by a specific caller_id publishes to a given endpoint + + void start(DateTime starttime, DateTime endtime); + void stop(); + + void handleConnectionRequest(const Dictionary & properties, size_t validated_id, bool is_subscribe); + void handleEndpointFailure(const std::string& endpoint_id, const std::string& reason, ClientStatusType status_type); + + void setupEndpoint(const std::string& endpoint_id, std::unique_ptr endpoint, std::string payload, bool persist, bool is_consumer, size_t validated_id); + void shutdownEndpoint(const std::string& endpoint_id); + + void addConsumer(const std::string& endpoint_id, size_t caller_id); + void addProducer(const std::string& endpoint_id, size_t caller_id); + bool canRemoveEndpoint(const std::string& endpoint_id); + + void removeEndpointForCallerId(const std::string& endpoint_id, bool is_consumer, size_t validated_id); + void removeConsumer(const std::string& endpoint_id, size_t caller_id); + void removeProducer(const std::string& endpoint_id, size_t caller_id); + + WebsocketEndpoint * getNonDynamicEndpoint(); + PushInputAdapter * getInputAdapter( CspTypePtr & type, PushMode pushMode, const Dictionary & properties ); + OutputAdapter * getOutputAdapter( const Dictionary & properties ); + OutputAdapter * getHeaderUpdateAdapter(); + OutputAdapter * getConnectionRequestAdapter( const Dictionary & properties ); +private: + inline size_t validateCallerId(int64_t caller_id) const { + if (caller_id < 0) { + CSP_THROW(ValueError, "caller_id cannot be negative: " << caller_id); + } + return static_cast(caller_id); + } + inline void ensureVectorSize(std::vector& vec, size_t caller_id) { + if (vec.size() <= caller_id) { + vec.resize(caller_id + 1, false); + } + } + // Whether the output adapater (publish) given by a specific caller_id publishes to a given endpoint + inline bool publishesToEndpoint(const size_t caller_id, const std::string& endpoint_id){ + auto config_it = m_endpoint_configs.find(endpoint_id); + if( config_it == m_endpoint_configs.end() || !config_it->second.connected ) + return false; + + return caller_id < m_endpoint_producers[endpoint_id].size() && + m_endpoint_producers[endpoint_id][caller_id]; + } + size_t m_num_threads; + net::io_context m_ioc; + Engine* m_engine; + boost::asio::strand m_strand; + ClientAdapterManager* m_mgr; + ClientHeaderUpdateOutputAdapter* m_updateAdapter; + std::vector> m_threads; + Dictionary m_properties; + std::vector m_connectionRequestAdapters; + + // Bidirectional mapping using vectors since caller_ids are sequential + // Maybe not efficient? Should be good for small number of edges though + std::unordered_map> m_endpoint_consumers; // endpoint_id -> vector[caller_id] for consuemrs + std::unordered_map> m_endpoint_producers; // endpoint_id -> vector[caller_id] for producers + + // Quick lookup for caller's endpoints + std::vector< std::unordered_set > m_consumer_endpoints; // caller_id -> set of endpoints they consume from + std::vector< std::unordered_set > m_producer_endpoints; // caller_id -> set of endpoints they produce to + boost::asio::executor_work_guard m_work_guard; + std::unordered_map> m_endpoints; + std::unordered_map m_endpoint_configs; + std::vector m_inputAdapters; + std::vector m_outputAdapters; + bool m_dynamic; +}; + +} +#endif \ No newline at end of file diff --git a/cpp/csp/python/adapters/websocketadapterimpl.cpp b/cpp/csp/python/adapters/websocketadapterimpl.cpp index d636932da..557ee4454 100644 --- a/cpp/csp/python/adapters/websocketadapterimpl.cpp +++ b/cpp/csp/python/adapters/websocketadapterimpl.cpp @@ -45,7 +45,11 @@ static OutputAdapter * create_websocket_output_adapter( csp::AdapterManager * ma auto * websocketManager = dynamic_cast( manager ); if( !websocketManager ) CSP_THROW( TypeError, "Expected WebsocketClientAdapterManager" ); - return websocketManager -> getOutputAdapter(); + PyObject * pyProperties; + if( !PyArg_ParseTuple( args, "O!", + &PyDict_Type, &pyProperties ) ) + CSP_THROW( PythonPassthrough, "" ); + return websocketManager -> getOutputAdapter(fromPython( pyProperties )); } static OutputAdapter * create_websocket_header_update_adapter( csp::AdapterManager * manager, PyEngine * pyengine, PyObject * args ) @@ -56,10 +60,25 @@ static OutputAdapter * create_websocket_header_update_adapter( csp::AdapterManag return websocketManager -> getHeaderUpdateAdapter(); } +static OutputAdapter * create_websocket_connection_request_adapter( csp::AdapterManager * manager, PyEngine * pyengine, PyObject * args ) +{ + PyObject * pyProperties; + auto * websocketManager = dynamic_cast( manager ); + if( !websocketManager ) + CSP_THROW( TypeError, "Expected WebsocketClientAdapterManager" ); + + if( !PyArg_ParseTuple( args, "O!", + &PyDict_Type, &pyProperties ) ) + CSP_THROW( PythonPassthrough, "" ); + return websocketManager -> getConnectionRequestAdapter( fromPython( pyProperties ) ); +} + REGISTER_ADAPTER_MANAGER( _websocket_adapter_manager, create_websocket_adapter_manager ); REGISTER_INPUT_ADAPTER( _websocket_input_adapter, create_websocket_input_adapter ); REGISTER_OUTPUT_ADAPTER( _websocket_output_adapter, create_websocket_output_adapter ); REGISTER_OUTPUT_ADAPTER( _websocket_header_update_adapter, create_websocket_header_update_adapter); +REGISTER_OUTPUT_ADAPTER( _websocket_connection_request_adapter, create_websocket_connection_request_adapter); + static PyModuleDef _websocketadapterimpl_module = { PyModuleDef_HEAD_INIT, diff --git a/csp/adapters/dynamic_adapter_utils.py b/csp/adapters/dynamic_adapter_utils.py new file mode 100644 index 000000000..1e8d1ae3a --- /dev/null +++ b/csp/adapters/dynamic_adapter_utils.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, NonNegativeInt + + +class AdapterInfo(BaseModel): + caller_id: NonNegativeInt + is_subscribe: bool diff --git a/csp/adapters/websocket.py b/csp/adapters/websocket.py index 2ba305c9a..1a88f48c2 100644 --- a/csp/adapters/websocket.py +++ b/csp/adapters/websocket.py @@ -8,6 +8,7 @@ import csp from csp import ts +from csp.adapters.dynamic_adapter_utils import AdapterInfo from csp.adapters.status import Status from csp.adapters.utils import ( BytesMessageProtoMapper, @@ -17,18 +18,26 @@ RawBytesMessageMapper, RawTextMessageMapper, ) +from csp.adapters.websocket_types import ( + ActionType, + ConnectionRequest, + InternalConnectionRequest, + WebsocketHeaderUpdate, + WebsocketStatus, +) from csp.impl.wiring import input_adapter_def, output_adapter_def, status_adapter_def from csp.impl.wiring.delayed_node import DelayedNodeWrapperDef from csp.lib import _websocketadapterimpl -from .websocket_types import WebsocketHeaderUpdate - +# InternalConnectionRequest, _ = ( + ActionType, BytesMessageProtoMapper, DateTimeType, JSONTextMessageMapper, RawBytesMessageMapper, RawTextMessageMapper, + WebsocketStatus, ) T = TypeVar("T") @@ -59,6 +68,12 @@ def diff_dict(old, new): return d +def _sanitize_port(uri: str, port): + if port: + return str(port) + return "443" if uri.startswith("wss") else "80" + + class TableManager: def __init__(self, tables, delta_updates): self._tables = tables @@ -237,7 +252,7 @@ def on_close(self): self._manager.unsubscribe(self) def on_message(self, message): - logging.info("got message %r", message) + logging.warning("got message %r", message) # TODO Ignore for now # parsed = rapidjson.loads(message) @@ -388,11 +403,22 @@ def _instantiate(self): class WebsocketAdapterManager: + """ + Can subscribe dynamically via ts[List[ConnectionRequest]] + + We use a ts[List[ConnectionRequest]] to allow users to submit a batch of conneciton requests in + a single engine cycle. + """ + def __init__( self, - uri: str, + uri: Optional[str] = None, reconnect_interval: timedelta = timedelta(seconds=2), - headers: Dict[str, str] = None, + headers: Optional[Dict[str, str]] = None, + dynamic: bool = False, + connection_request: Optional[ConnectionRequest] = None, + num_threads: int = 1, + binary: bool = False, ): """ uri: str @@ -401,26 +427,86 @@ def __init__( time interval to wait before trying to reconnect (must be >= 1 second) headers: Dict[str, str] = None headers to apply to the request during the handshake + dynamic: bool = False + Whether we accept dynamically altering the connections via ConnectionRequest objects. + num_threads: int = 1 + Determines number of threads to allocate for running the websocket endpoints. + Defaults to 1 to avoid thread switching + binary: bool = False + Whether to send/receive text or binary data """ + + self._properties = dict(dynamic=dynamic, num_threads=num_threads, binary=binary) + # Enumerating for clarity + if connection_request is not None and uri is not None: + raise ValueError("'connection_request' cannot be set along with 'uri'") + + # Exactly 1 of connection_request and uri is None + if connection_request is not None or uri is not None: + if connection_request is None: + connection_request = ConnectionRequest( + uri=uri, reconnect_interval=reconnect_interval, headers=headers or {} + ) + self._properties.update(self._get_properties(connection_request).to_dict()) + + # This is a counter that will be used to identify every function call + # We keep track of the subscribes and sends separately + self._subscribe_call_id = 0 + self._send_call_id = 0 + + # This maps types to their wrapper structs + self._wrapper_struct_dict = {} + + @property + def _dynamic(self): + return self._properties.get("dynamic", False) + + def _get_properties(self, conn_request: ConnectionRequest) -> InternalConnectionRequest: + uri = conn_request.uri + reconnect_interval = conn_request.reconnect_interval + assert reconnect_interval >= timedelta(seconds=1) resp = urllib.parse.urlparse(uri) if resp.hostname is None: raise ValueError(f"Failed to parse host from URI: {uri}") - self._properties = dict( + res = InternalConnectionRequest( host=resp.hostname, # if no port is explicitly present in the uri, the resp.port is None - port=self._sanitize_port(uri, resp.port), + port=_sanitize_port(uri, resp.port), route=resp.path or "/", # resource shouldn't be empty string use_ssl=uri.startswith("wss"), reconnect_interval=reconnect_interval, - headers=headers if headers else {}, + headers=rapidjson.dumps(conn_request.headers) if conn_request.headers else "", + persistent=conn_request.persistent, + action=conn_request.action.name, + on_connect_payload=conn_request.on_connect_payload, + uri=uri, + dynamic=self._dynamic, + binary=self._properties.get("binary", False), ) + return res - def _sanitize_port(self, uri: str, port): - if port: - return str(port) - return "443" if uri.startswith("wss") else "80" + def _get_caller_id(self, is_subscribe: bool) -> int: + if is_subscribe: + caller_id = self._subscribe_call_id + self._subscribe_call_id += 1 + else: + caller_id = self._send_call_id + self._send_call_id += 1 + return caller_id + + def get_wrapper_struct(self, ts_type: type): + if (dynamic_type := self._wrapper_struct_dict.get(ts_type)) is None: + # I want to preserve type information + # Not sure a better way to do this + class CustomWrapperStruct(csp.Struct): + msg: ts_type # noqa + uri: str + + dynamic_type = CustomWrapperStruct + self._wrapper_struct_dict[ts_type] = dynamic_type + return dynamic_type def subscribe( self, @@ -429,7 +515,29 @@ def subscribe( field_map: Union[dict, str] = None, meta_field_map: dict = None, push_mode: csp.PushMode = csp.PushMode.NON_COLLAPSING, + connection_request: Optional[ts[List[ConnectionRequest]]] = None, ): + """If dynamic is True, this will tick a custom WrapperStruct, + with 'msg' as the correct type of the message. + And 'uri' that specifies the 'uri' the message comes from. + + Otherwise, returns just message. + + ts_type should be original type!! The tuple wrapping happens + automatically + """ + caller_id = self._get_caller_id(is_subscribe=True) + # Gives validation, more to start defining a common interface + adapter_props = AdapterInfo(caller_id=caller_id, is_subscribe=True).model_dump() + connection_request = csp.null_ts(List[ConnectionRequest]) if connection_request is None else connection_request + request_dict = csp.apply( + connection_request, + lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], + List[InternalConnectionRequest], + ) + # Output adapter to handle connection requests + _websocket_connection_request_adapter_def(self, request_dict, adapter_props) + field_map = field_map or {} meta_field_map = meta_field_map or {} if isinstance(field_map, str): @@ -442,20 +550,37 @@ def subscribe( properties["field_map"] = field_map properties["meta_field_map"] = meta_field_map + properties.update(adapter_props) + # We wrap the message in a struct to note the url it comes from + if self._dynamic: + ts_type = self.get_wrapper_struct(ts_type=ts_type) return _websocket_input_adapter_def(self, ts_type, properties, push_mode=push_mode) - def send(self, x: ts["T"]): - return _websocket_output_adapter_def(self, x) + def send(self, x: ts["T"], connection_request: Optional[ts[List[ConnectionRequest]]] = None): + caller_id = self._get_caller_id(is_subscribe=False) + # Gives validation, more to start defining a common interface + adapter_props = AdapterInfo(caller_id=caller_id, is_subscribe=False).model_dump() + connection_request = csp.null_ts(List[ConnectionRequest]) if connection_request is None else connection_request + request_dict = csp.apply( + connection_request, + lambda conn_reqs: [self._get_properties(conn_req) for conn_req in conn_reqs], + List[InternalConnectionRequest], + ) + _websocket_connection_request_adapter_def(self, request_dict, adapter_props) + return _websocket_output_adapter_def(self, x, adapter_props) def update_headers(self, x: ts[List[WebsocketHeaderUpdate]]): + if self._dynamic: + raise ValueError("If dynamic, cannot call update_headers") return _websocket_header_update_adapter_def(self, x) def status(self, push_mode=csp.PushMode.NON_COLLAPSING): ts_type = Status - return status_adapter_def(self, ts_type, push_mode=push_mode) + return status_adapter_def(self, ts_type, push_mode) def _create(self, engine, memo): """method needs to return the wrapped c++ adapter manager""" + self._properties.update({"subscribe_calls": self._subscribe_call_id, "send_calls": self._send_call_id}) return _websocketadapterimpl._websocket_adapter_manager(engine, self._properties) @@ -473,6 +598,7 @@ def _create(self, engine, memo): _websocketadapterimpl._websocket_output_adapter, WebsocketAdapterManager, input=ts["T"], + properties=dict, ) _websocket_header_update_adapter_def = output_adapter_def( @@ -481,3 +607,11 @@ def _create(self, engine, memo): WebsocketAdapterManager, input=ts[List[WebsocketHeaderUpdate]], ) + +_websocket_connection_request_adapter_def = output_adapter_def( + "websocket_connection_request_adapter", + _websocketadapterimpl._websocket_connection_request_adapter, + WebsocketAdapterManager, + input=ts[List[InternalConnectionRequest]], # needed, List[dict] didn't work on c++ level + properties=dict, +) diff --git a/csp/adapters/websocket_types.py b/csp/adapters/websocket_types.py index 710610501..542a1c205 100644 --- a/csp/adapters/websocket_types.py +++ b/csp/adapters/websocket_types.py @@ -1,3 +1,6 @@ +from datetime import timedelta +from typing import Dict + from csp.impl.enum import Enum from csp.impl.struct import Struct @@ -12,6 +15,44 @@ class WebsocketStatus(Enum): MESSAGE_SEND_FAIL = 4 +class ActionType(Enum): + CONNECT = 0 + DISCONNECT = 1 + PING = 2 + + class WebsocketHeaderUpdate(Struct): key: str value: str + + +class ConnectionRequest(Struct): + uri: str + action: ActionType = ActionType.CONNECT # Connect, Disconnect, Ping, etc + # Whether we maintain the connection + persistent: bool = True # Only relevant for Connect requests + reconnect_interval: timedelta = timedelta(seconds=2) + on_connect_payload: str = "" # message to send on connect + headers: Dict[str, str] = {} + + +# Only used internally +class InternalConnectionRequest(Struct): + host: str # Hostname parsed from the URI + port: str # Port number for the connection (parsed and sanitized from URI) + route: str # Resource path from URI, defaults to "/" if empty + uri: str # Complete original URI string + + # Connection behavior + use_ssl: bool # Whether to use secure WebSocket (wss://) + reconnect_interval: timedelta # Time to wait between reconnection attempts + persistent: bool # Whether to maintain a persistent connection + + # Headers and payloads + headers: str # HTTP headers for the connection as json string + on_connect_payload: str # Message to send when connection is established + + # Connection metadata + action: str # Connection action type (Connect, Disconnect, Ping, etc) + dynamic: bool # Whether the connection is dynamic + binary: bool # Whether to use binary mode for the connection diff --git a/csp/tests/adapters/test_websocket.py b/csp/tests/adapters/test_websocket.py index 4a33a3e12..2a013e240 100644 --- a/csp/tests/adapters/test_websocket.py +++ b/csp/tests/adapters/test_websocket.py @@ -1,41 +1,102 @@ import os +import pytest import pytz import threading import tornado.ioloop import tornado.web import tornado.websocket -import unittest +from contextlib import contextmanager from datetime import datetime, timedelta from tornado.testing import bind_unused_port from typing import List import csp from csp import ts -from csp.adapters.websocket import JSONTextMessageMapper, RawTextMessageMapper, Status, WebsocketAdapterManager +from csp.adapters.websocket import ( + ActionType, + ConnectionRequest, + JSONTextMessageMapper, + RawTextMessageMapper, + Status, + WebsocketAdapterManager, + WebsocketHeaderUpdate, + WebsocketStatus, +) class EchoWebsocketHandler(tornado.websocket.WebSocketHandler): def on_message(self, msg): + # Carve-out to allow inspecting the headers + if msg == "header1": + msg = self.request.headers.get(msg, "") + elif not isinstance(msg, str) and msg.decode("utf-8") == "header1": + # Need this for bytes + msg = self.request.headers.get("header1", "") return self.write_message(msg) -class TestWebsocket(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) +@contextmanager +def create_tornado_server(port: int = None): + """Base context manager for creating a Tornado server in a thread + + Args: + port: Optional port number. If None, an unused port will be chosen. + + Returns: + Tuple containing (io_loop, app, io_thread, port) + """ + ready_event = threading.Event() + io_loop = None + app = None + io_thread = None + + # Get an unused port if none specified + if port is None: sock, port = bind_unused_port() sock.close() - cls.port = port - cls.app.listen(port) - cls.io_loop = tornado.ioloop.IOLoop.current() - cls.io_thread = threading.Thread(target=cls.io_loop.start) - cls.io_thread.start() - - @classmethod - def tearDownClass(cls): - cls.io_loop.add_callback(cls.io_loop.stop) - if cls.io_thread: - cls.io_thread.join() + + def run_io_loop(): + nonlocal io_loop, app + io_loop = tornado.ioloop.IOLoop() + app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) + try: + app.listen(port) + ready_event.set() + io_loop.start() + except Exception as e: + ready_event.set() # Ensure we don't hang in case of error + raise + + io_thread = threading.Thread(target=run_io_loop) + io_thread.start() + ready_event.wait() + + try: + yield io_loop, app, io_thread, port + finally: + io_loop.add_callback(io_loop.stop) + if io_thread: + io_thread.join(timeout=5) + if io_thread.is_alive(): + raise RuntimeError("IOLoop failed to stop") + + +@contextmanager +def tornado_server(): + """Simplified context manager that uses the base implementation with dynamic port""" + with create_tornado_server() as (_io_loop, _app, _io_thread, port): + yield port + + +class TestWebsocket: + @pytest.fixture(scope="class", autouse=True) + def setup_tornado(self, request): + with create_tornado_server() as (io_loop, app, io_thread, port): + request.cls.io_loop = io_loop + request.cls.app = app + request.cls.io_thread = io_thread + request.cls.port = port # Make the port available to tests + yield def test_send_recv_msg(self): @csp.node @@ -56,6 +117,113 @@ def g(): msgs = csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) assert msgs["recv"][0][1] == "Hello, World!" + @pytest.mark.parametrize("binary", [False, True]) + def test_headers(self, binary): + @csp.graph + def g(dynamic: bool): + if dynamic: + ws = WebsocketAdapterManager(dynamic=True, binary=binary) + conn_request1 = csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + on_connect_payload="header1", + headers={"header1": "value1"}, + ) + ] + ) + conn_request2 = csp.const( + [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.DISCONNECT)], + delay=timedelta(milliseconds=100), + ) + conn_request3 = csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + on_connect_payload="header1", + headers={"header1": "value2"}, + ) + ], + delay=timedelta(milliseconds=150), + ) + conn_request4 = csp.const( + [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.PING)], + delay=timedelta(milliseconds=151), + ) + conn_request5 = csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + on_connect_payload="header1", + headers={"header1": "value2"}, + ) + ], + delay=timedelta(milliseconds=200), + ) + conn_req = csp.flatten([conn_request1, conn_request2, conn_request3, conn_request4, conn_request5]) + status = ws.status() + csp.add_graph_output("status", status) + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_req) + csp.add_graph_output("recv", recv) + stop = csp.filter(csp.count(recv) == 2, recv) + csp.stop_engine(stop) + + if not dynamic: + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/", headers={"header1": "value1"}) + status = ws.status() + send_msg = csp.sample(status, csp.const("header1")) + to_send = csp.merge(send_msg, csp.const("header1", delay=timedelta(milliseconds=100))) + ws.send(to_send) + recv = ws.subscribe(str, RawTextMessageMapper()) + + header_update = csp.const( + [WebsocketHeaderUpdate(key="header1", value="value2")], delay=timedelta(milliseconds=50) + ) + ws.update_headers(header_update) + status = ws.status() + csp.add_graph_output("status", status) + + csp.add_graph_output("recv", recv) + csp.stop_engine(recv) + + @pytest.mark.parametrize("send_payload_subscribe", [True, False]) + def test_send_recv_json_dynamic_on_connect_payload(self, send_payload_subscribe): + class MsgStruct(csp.Struct): + a: int + b: str + + @csp.graph + def g(): + ws = WebsocketAdapterManager(dynamic=True) + conn_request = ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + action=ActionType.CONNECT, + on_connect_payload=MsgStruct(a=1234, b="im a string").to_json(), + ) + if not send_payload_subscribe: + # We send payload via the dummy send function + # The 'on_connect_payload sends the result + ws.send(csp.null_ts(object), connection_request=csp.const([conn_request])) + subscribe_connection_request = ( + [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.CONNECT)] + if not send_payload_subscribe + else [conn_request] + ) + recv = ws.subscribe( + MsgStruct, JSONTextMessageMapper(), connection_request=csp.const(subscribe_connection_request) + ) + + csp.add_graph_output("recv", recv) + csp.stop_engine(recv) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) + obj = msgs["recv"][0][1] + assert obj.uri == f"ws://localhost:{self.port}/" + true_obj = obj.msg + assert isinstance(true_obj, MsgStruct) + assert true_obj.a == 1234 + assert true_obj.b == "im a string" + def test_send_recv_json(self): class MsgStruct(csp.Struct): a: int @@ -120,37 +288,254 @@ def g(n: int): assert len(msgs["recv"]) == n assert msgs["recv"][0][1] != msgs["recv"][-1][1] - def test_unkown_host_graceful_shutdown(self): + def test_send_multiple_and_recv_msgs_dynamic(self): @csp.graph def g(): - ws = WebsocketAdapterManager("wss://localhost/") - assert ws._properties["port"] == "443" - csp.stop_engine(ws.status()) + ws = WebsocketAdapterManager(dynamic=True) + conn_request = csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + action=ActionType.CONNECT, + ) + ] + ) + val = csp.curve(int, [(timedelta(milliseconds=100), 0), (timedelta(milliseconds=500), 1)]) + hello = csp.apply(val, lambda x: f"hi world{x}", str) + delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=300)) - csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) + # We connect immediately and send out the hello message + ws.send(hello, connection_request=conn_request) - def test_send_recv_burst_json(self): - class MsgStruct(csp.Struct): - a: int - b: str + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=delayed_conn_req) + # This call connects first + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + + merged = csp.flatten([recv, recv2]) + csp.add_graph_output("recv", merged.msg) + + stop = csp.filter(csp.count(merged) == 3, merged) + csp.stop_engine(stop) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=1), realtime=True) + assert len(msgs["recv"]) == 3 + # the first message sent out, only the second subscribe call picks this up + assert msgs["recv"][0][1] == "hi world0" + # Both the subscribe calls receive this message + assert msgs["recv"][1][1] == "hi world1" + assert msgs["recv"][2][1] == "hi world1" + + @pytest.mark.parametrize("reconnect_immeditately", [False, True]) + def test_dynamic_disconnect_connect_pruned_subscribe(self, reconnect_immeditately): + @csp.graph + def g(): + ws = WebsocketAdapterManager(dynamic=True) + + if reconnect_immeditately: + disconnect_reqs = [ + ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.DISCONNECT), + ConnectionRequest(uri=f"ws://localhost:{self.port}/"), + ] + else: + disconnect_reqs = [ConnectionRequest(uri=f"ws://localhost:{self.port}/", action=ActionType.DISCONNECT)] + conn_request = csp.curve( + List[ConnectionRequest], + [ + (timedelta(), [ConnectionRequest(uri=f"ws://localhost:{self.port}/")]), + ( + timedelta(milliseconds=100), + disconnect_reqs, + ), + ( + timedelta(milliseconds=700), + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + headers={"dummy_key": "dummy_value"}, + ), + ], + ), + ], + ) + const_conn_request = csp.const([ConnectionRequest(uri=f"ws://localhost:{self.port}/")]) + val = csp.curve(int, [(timedelta(milliseconds=300, microseconds=1), 0), (timedelta(milliseconds=900), 1)]) + hello = csp.apply(val, lambda x: f"hi world{x}", str) + + # We connect immediately and send out the hello message + ws.send(hello, connection_request=const_conn_request) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + # This gets pruned by csp + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + recv3 = ws.subscribe(str, RawTextMessageMapper(), connection_request=const_conn_request) + + no_persist_conn = ConnectionRequest( + uri=f"ws://localhost:{self.port}/", persistent=False, on_connect_payload="hi non-persistent world!" + ) + recv4 = ws.subscribe( + str, + RawTextMessageMapper(), + connection_request=csp.const([no_persist_conn], delay=timedelta(milliseconds=500)), + ) + csp.add_graph_output("recv", recv) + csp.add_graph_output("recv3", recv3) + csp.add_graph_output("recv4", recv4) + end = csp.filter(csp.count(recv3) == 3, recv3) + csp.stop_engine(end) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=2), realtime=True) + # Did not persist, so did not receive any messages + assert len(msgs["recv4"]) == 0 + # Only the second message is received, since we disonnect before the first one is sent + if not reconnect_immeditately: + assert len(msgs["recv"]) == 1 + assert msgs["recv"][0][1].msg == "hi world1" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" + else: + assert len(msgs["recv"]) == 3 + assert msgs["recv"][0][1].msg == "hi world0" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" + assert msgs["recv"][1][1].msg == "hi non-persistent world!" + assert msgs["recv"][1][1].uri == f"ws://localhost:{self.port}/" + assert msgs["recv"][2][1].msg == "hi world1" + assert msgs["recv"][2][1].uri == f"ws://localhost:{self.port}/" + + # This subscribe call received all the messages + assert len(msgs["recv3"]) == 3 + assert msgs["recv3"][0][1].msg == "hi world0" + assert msgs["recv3"][0][1].uri == f"ws://localhost:{self.port}/" + assert msgs["recv3"][1][1].msg == "hi non-persistent world!" + assert msgs["recv3"][1][1].uri == f"ws://localhost:{self.port}/" + assert msgs["recv3"][2][1].msg == "hi world1" + assert msgs["recv3"][2][1].uri == f"ws://localhost:{self.port}/" + + def test_dynamic_pruned_subscribe(self): + @csp.graph + def g(): + ws = WebsocketAdapterManager(dynamic=True) + conn_request = csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + action=ActionType.CONNECT, + ) + ] + ) + val = csp.curve(int, [(timedelta(milliseconds=100), 0), (timedelta(milliseconds=600), 1)]) + hello = csp.apply(val, lambda x: f"hi world{x}", str) + delayed_conn_req = csp.delay(conn_request, delay=timedelta(milliseconds=400)) + + # We connect immediately and send out the hello message + ws.send(hello, connection_request=conn_request) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=delayed_conn_req) + # This gets pruned by csp + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request) + + csp.add_graph_output("recv", recv) + csp.stop_engine(recv) + + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=2), realtime=True) + assert len(msgs["recv"]) == 1 + # Only the second message is received + assert msgs["recv"][0][1].msg == "hi world1" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" + + def test_dynamic_multiple_subscribers(self): @csp.node - def send_msg_on_open(status: ts[Status]) -> ts[str]: + def send_on_status(status: ts[Status], uri: str, val: str) -> ts[str]: if csp.ticked(status): - return MsgStruct(a=1234, b="im a string").to_json() + if uri in status.msg and status.status_code == WebsocketStatus.ACTIVE.value: + return val + + with tornado_server() as port2: # Get a second dynamic port + + @csp.graph + def g(use_on_connect_payload: bool): + ws = WebsocketAdapterManager(dynamic=True) + if use_on_connect_payload: + conn_request1 = csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", on_connect_payload="hey world from main" + ) + ] + ) + conn_request2 = csp.const( + [ConnectionRequest(uri=f"ws://localhost:{port2}/", on_connect_payload="hey world from second")] + ) + else: + conn_request1 = csp.const([ConnectionRequest(uri=f"ws://localhost:{self.port}/")]) + conn_request2 = csp.const([ConnectionRequest(uri=f"ws://localhost:{port2}/")]) + status = ws.status() + to_send = send_on_status(status, f"ws://localhost:{self.port}/", "hey world from main") + to_send2 = send_on_status(status, f"ws://localhost:{port2}/", "hey world from second") + ws.send(to_send, connection_request=conn_request1) + ws.send(to_send2, connection_request=conn_request2) + + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request1) + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request2) + + csp.add_graph_output("recv", recv) + csp.add_graph_output("recv2", recv2) + + merged = csp.flatten([recv, recv2]) + stop = csp.filter(csp.count(merged) == 2, merged) + csp.stop_engine(stop) + + for use_on_connect_payload in [True, False]: + msgs = csp.run( + g, + use_on_connect_payload, + starttime=datetime.now(pytz.UTC), + endtime=timedelta(seconds=5), + realtime=True, + ) + assert len(msgs["recv"]) == 1 + assert msgs["recv"][0][1].msg == "hey world from main" + assert msgs["recv"][0][1].uri == f"ws://localhost:{self.port}/" + assert len(msgs["recv2"]) == 1 + assert msgs["recv2"][0][1].msg == "hey world from second" + assert msgs["recv2"][0][1].uri == f"ws://localhost:{port2}/" + + @pytest.mark.parametrize("dynamic", [False, True]) + def test_send_recv_burst_json(self, dynamic): + class MsgStruct(csp.Struct): + a: int + b: str @csp.node - def my_edge_that_handles_burst(objs: ts[List[MsgStruct]]) -> ts[bool]: + def my_edge_that_handles_burst(objs: ts[List[MsgStruct]]): if csp.ticked(objs): - return True + # Does nothing but makes sure it's not pruned + ... @csp.graph def g(): - ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/") - status = ws.status() - ws.send(send_msg_on_open(status)) - recv = ws.subscribe(MsgStruct, JSONTextMessageMapper(), push_mode=csp.PushMode.BURST) - _ = my_edge_that_handles_burst(recv) + if dynamic: + ws = WebsocketAdapterManager(dynamic=True) + wrapped_recv = ws.subscribe( + MsgStruct, + JSONTextMessageMapper(), + push_mode=csp.PushMode.BURST, + connection_request=csp.const( + [ + ConnectionRequest( + uri=f"ws://localhost:{self.port}/", + on_connect_payload=MsgStruct(a=1234, b="im a string").to_json(), + ) + ] + ), + ) + recv = csp.apply(wrapped_recv, lambda vals: [v.msg for v in vals], List[MsgStruct]) + else: + ws = WebsocketAdapterManager(f"ws://localhost:{self.port}/") + status = ws.status() + ws.send(csp.apply(status, lambda _x: MsgStruct(a=1234, b="im a string").to_json(), str)) + recv = ws.subscribe(MsgStruct, JSONTextMessageMapper(), push_mode=csp.PushMode.BURST) + + my_edge_that_handles_burst(recv) csp.add_graph_output("recv", recv) csp.stop_engine(recv) @@ -160,3 +545,28 @@ def g(): innerObj = obj[0] assert innerObj.a == 1234 assert innerObj.b == "im a string" + + def test_unkown_host_graceful_shutdown(self): + @csp.graph + def g(): + ws = WebsocketAdapterManager("wss://localhost/") + # We need this since without any input or output + # adapters, the websocket connection is not actually made. + ws.send(csp.null_ts(str)) + assert ws._properties["port"] == "443" + csp.stop_engine(ws.status()) + + csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True) + + def test_unkown_host_graceful_shutdown_slow(self): + @csp.graph + def g(): + ws = WebsocketAdapterManager("wss://localhost/") + # We need this since without any input or output + # adapters, the websocket connection is not actually made. + ws.send(csp.null_ts(str)) + assert ws._properties["port"] == "443" + stop_flag = csp.filter(csp.count(ws.status()) == 2, ws.status()) + csp.stop_engine(stop_flag) + + csp.run(g, starttime=datetime.now(pytz.UTC), realtime=True)