diff options
Diffstat (limited to 'modules/websocket')
20 files changed, 171 insertions, 121 deletions
diff --git a/modules/websocket/SCsub b/modules/websocket/SCsub index 63c941c4a8..dc0661995f 100644 --- a/modules/websocket/SCsub +++ b/modules/websocket/SCsub @@ -41,6 +41,8 @@ elif env["builtin_wslay"]: module_obj = [] env_ws.add_source_files(module_obj, "*.cpp") +if env["tools"]: + env_ws.add_source_files(module_obj, "editor/*.cpp") env.modules_sources += module_obj # Needed to force rebuilding the module files when the thirdparty library is updated. diff --git a/modules/websocket/doc_classes/WebSocketServer.xml b/modules/websocket/doc_classes/WebSocketServer.xml index ef3279aac4..46b0274de3 100644 --- a/modules/websocket/doc_classes/WebSocketServer.xml +++ b/modules/websocket/doc_classes/WebSocketServer.xml @@ -60,6 +60,13 @@ If [code]false[/code] is passed instead (default), you must call [PacketPeer] functions ([code]put_packet[/code], [code]get_packet[/code], etc.), on the [WebSocketPeer] returned via [code]get_peer(id)[/code] to communicate with the peer with given [code]id[/code] (e.g. [code]get_peer(id).get_available_packet_count[/code]). </description> </method> + <method name="set_extra_headers"> + <return type="void" /> + <argument index="0" name="headers" type="PackedStringArray" default="PackedStringArray()" /> + <description> + Sets additional headers to be sent to clients during the HTTP handshake. + </description> + </method> <method name="stop"> <return type="void" /> <description> diff --git a/modules/websocket/editor_debugger_server_websocket.cpp b/modules/websocket/editor/editor_debugger_server_websocket.cpp index 4f1a56f00b..0443147d98 100644 --- a/modules/websocket/editor_debugger_server_websocket.cpp +++ b/modules/websocket/editor/editor_debugger_server_websocket.cpp @@ -30,11 +30,13 @@ #include "editor_debugger_server_websocket.h" +#ifdef TOOLS_ENABLED + +#include "../remote_debugger_peer_websocket.h" #include "core/config/project_settings.h" #include "editor/editor_log.h" #include "editor/editor_node.h" #include "editor/editor_settings.h" -#include "modules/websocket/remote_debugger_peer_websocket.h" void EditorDebuggerServerWebSocket::_peer_connected(int p_id, String _protocol) { pending_peers.push_back(p_id); @@ -129,3 +131,5 @@ EditorDebuggerServer *EditorDebuggerServerWebSocket::create(const String &p_prot ERR_FAIL_COND_V(p_protocol != "ws://", nullptr); return memnew(EditorDebuggerServerWebSocket); } + +#endif // TOOLS_ENABLED diff --git a/modules/websocket/editor_debugger_server_websocket.h b/modules/websocket/editor/editor_debugger_server_websocket.h index cc14bf62ba..7c0705302d 100644 --- a/modules/websocket/editor_debugger_server_websocket.h +++ b/modules/websocket/editor/editor_debugger_server_websocket.h @@ -31,8 +31,10 @@ #ifndef EDITOR_DEBUGGER_SERVER_WEBSOCKET_H #define EDITOR_DEBUGGER_SERVER_WEBSOCKET_H +#ifdef TOOLS_ENABLED + +#include "../websocket_server.h" #include "editor/debugger/editor_debugger_server.h" -#include "modules/websocket/websocket_server.h" class EditorDebuggerServerWebSocket : public EditorDebuggerServer { GDCLASS(EditorDebuggerServerWebSocket, EditorDebuggerServer); @@ -60,4 +62,6 @@ public: ~EditorDebuggerServerWebSocket(); }; +#endif // TOOLS_ENABLED + #endif // EDITOR_DEBUGGER_SERVER_WEBSOCKET_H diff --git a/modules/websocket/emws_client.h b/modules/websocket/emws_client.h index 61ea0002ea..ca327a56fa 100644 --- a/modules/websocket/emws_client.h +++ b/modules/websocket/emws_client.h @@ -53,15 +53,15 @@ private: static void _esws_on_close(void *obj, int code, const char *reason, int was_clean); public: - Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets); - Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()); - Ref<WebSocketPeer> get_peer(int p_peer_id) const; - void disconnect_from_host(int p_code = 1000, String p_reason = ""); - IPAddress get_connected_host() const; - uint16_t get_connected_port() const; - virtual ConnectionStatus get_connection_status() const; - int get_max_packet_size() const; - virtual void poll(); + Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override; + Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()) override; + Ref<WebSocketPeer> get_peer(int p_peer_id) const override; + void disconnect_from_host(int p_code = 1000, String p_reason = "") override; + IPAddress get_connected_host() const override; + uint16_t get_connected_port() const override; + virtual ConnectionStatus get_connection_status() const override; + int get_max_packet_size() const override; + virtual void poll() override; EMWSClient(); ~EMWSClient(); }; diff --git a/modules/websocket/emws_peer.h b/modules/websocket/emws_peer.h index df63d2d801..6bb4552c37 100644 --- a/modules/websocket/emws_peer.h +++ b/modules/websocket/emws_peer.h @@ -68,21 +68,21 @@ private: public: Error read_msg(const uint8_t *p_data, uint32_t p_size, bool p_is_string); void set_sock(int p_sock, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size); - virtual int get_available_packet_count() const; - virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size); - virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size); - virtual int get_max_packet_size() const { return _packet_buffer.size(); }; - virtual int get_current_outbound_buffered_amount() const; + virtual int get_available_packet_count() const override; + virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size) override; + virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size) override; + virtual int get_max_packet_size() const override { return _packet_buffer.size(); }; + virtual int get_current_outbound_buffered_amount() const override; - virtual void close(int p_code = 1000, String p_reason = ""); - virtual bool is_connected_to_host() const; - virtual IPAddress get_connected_host() const; - virtual uint16_t get_connected_port() const; + virtual void close(int p_code = 1000, String p_reason = "") override; + virtual bool is_connected_to_host() const override; + virtual IPAddress get_connected_host() const override; + virtual uint16_t get_connected_port() const override; - virtual WriteMode get_write_mode() const; - virtual void set_write_mode(WriteMode p_mode); - virtual bool was_string_packet() const; - virtual void set_no_delay(bool p_enabled); + virtual WriteMode get_write_mode() const override; + virtual void set_write_mode(WriteMode p_mode) override; + virtual bool was_string_packet() const override; + virtual void set_no_delay(bool p_enabled) override; EMWSPeer(); ~EMWSPeer(); diff --git a/modules/websocket/emws_server.cpp b/modules/websocket/emws_server.cpp index 53b4a0207d..2033098cad 100644 --- a/modules/websocket/emws_server.cpp +++ b/modules/websocket/emws_server.cpp @@ -33,6 +33,9 @@ #include "emws_server.h" #include "core/os/os.h" +void EMWSServer::set_extra_headers(const Vector<String> &p_headers) { +} + Error EMWSServer::listen(int p_port, Vector<String> p_protocols, bool gd_mp_api) { return FAILED; } diff --git a/modules/websocket/emws_server.h b/modules/websocket/emws_server.h index f310c17c9d..ae31d9dbb0 100644 --- a/modules/websocket/emws_server.h +++ b/modules/websocket/emws_server.h @@ -41,17 +41,18 @@ class EMWSServer : public WebSocketServer { GDCIIMPL(EMWSServer, WebSocketServer); public: - Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets); - Error listen(int p_port, Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false); - void stop(); - bool is_listening() const; - bool has_peer(int p_id) const; - Ref<WebSocketPeer> get_peer(int p_id) const; - IPAddress get_peer_address(int p_peer_id) const; - int get_peer_port(int p_peer_id) const; - void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = ""); - int get_max_packet_size() const; - virtual void poll(); + Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override; + void set_extra_headers(const Vector<String> &p_headers) override; + Error listen(int p_port, Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) override; + void stop() override; + bool is_listening() const override; + bool has_peer(int p_id) const override; + Ref<WebSocketPeer> get_peer(int p_id) const override; + IPAddress get_peer_address(int p_peer_id) const override; + int get_peer_port(int p_peer_id) const override; + void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") override; + int get_max_packet_size() const override; + virtual void poll() override; virtual Vector<String> get_protocols() const; EMWSServer(); diff --git a/modules/websocket/library_godot_websocket.js b/modules/websocket/library_godot_websocket.js index c88986fbe3..57f1f10b02 100644 --- a/modules/websocket/library_godot_websocket.js +++ b/modules/websocket/library_godot_websocket.js @@ -135,7 +135,7 @@ const GodotWebSocket = { if (!ref) { return; } - GodotWebSocket.close(p_id, 1001, ''); + GodotWebSocket.close(p_id, 3001, 'destroyed'); IDHandler.remove(p_id); ref.onopen = null; ref.onmessage = null; diff --git a/modules/websocket/register_types.cpp b/modules/websocket/register_types.cpp index 1e9a4c0392..f562de111f 100644 --- a/modules/websocket/register_types.cpp +++ b/modules/websocket/register_types.cpp @@ -29,8 +29,10 @@ /*************************************************************************/ #include "register_types.h" + #include "core/config/project_settings.h" #include "core/error/error_macros.h" + #ifdef JAVASCRIPT_ENABLED #include "emscripten.h" #include "emws_client.h" @@ -40,10 +42,11 @@ #include "wsl_client.h" #include "wsl_server.h" #endif + #ifdef TOOLS_ENABLED #include "editor/debugger/editor_debugger_server.h" +#include "editor/editor_debugger_server_websocket.h" #include "editor/editor_node.h" -#include "editor_debugger_server_websocket.h" #endif #ifdef TOOLS_ENABLED @@ -52,25 +55,33 @@ static void _editor_init_callback() { } #endif -void register_websocket_types() { +void initialize_websocket_module(ModuleInitializationLevel p_level) { + if (p_level == MODULE_INITIALIZATION_LEVEL_SCENE) { #ifdef JAVASCRIPT_ENABLED - EMWSPeer::make_default(); - EMWSClient::make_default(); - EMWSServer::make_default(); + EMWSPeer::make_default(); + EMWSClient::make_default(); + EMWSServer::make_default(); #else - WSLPeer::make_default(); - WSLClient::make_default(); - WSLServer::make_default(); + WSLPeer::make_default(); + WSLClient::make_default(); + WSLServer::make_default(); #endif - GDREGISTER_VIRTUAL_CLASS(WebSocketMultiplayerPeer); - ClassDB::register_custom_instance_class<WebSocketServer>(); - ClassDB::register_custom_instance_class<WebSocketClient>(); - ClassDB::register_custom_instance_class<WebSocketPeer>(); + GDREGISTER_ABSTRACT_CLASS(WebSocketMultiplayerPeer); + ClassDB::register_custom_instance_class<WebSocketServer>(); + ClassDB::register_custom_instance_class<WebSocketClient>(); + ClassDB::register_custom_instance_class<WebSocketPeer>(); + } #ifdef TOOLS_ENABLED - EditorNode::add_init_callback(&_editor_init_callback); + if (p_level == MODULE_INITIALIZATION_LEVEL_EDITOR) { + EditorNode::add_init_callback(&_editor_init_callback); + } #endif } -void unregister_websocket_types() {} +void uninitialize_websocket_module(ModuleInitializationLevel p_level) { + if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) { + return; + } +} diff --git a/modules/websocket/register_types.h b/modules/websocket/register_types.h index 4ab6c0cfd3..dab42d6ed9 100644 --- a/modules/websocket/register_types.h +++ b/modules/websocket/register_types.h @@ -31,7 +31,9 @@ #ifndef WEBSOCKET_REGISTER_TYPES_H #define WEBSOCKET_REGISTER_TYPES_H -void register_websocket_types(); -void unregister_websocket_types(); +#include "modules/register_module_types.h" + +void initialize_websocket_module(ModuleInitializationLevel p_level); +void uninitialize_websocket_module(ModuleInitializationLevel p_level); #endif // WEBSOCKET_REGISTER_TYPES_H diff --git a/modules/websocket/remote_debugger_peer_websocket.h b/modules/websocket/remote_debugger_peer_websocket.h index ddf5425d81..3227065ded 100644 --- a/modules/websocket/remote_debugger_peer_websocket.h +++ b/modules/websocket/remote_debugger_peer_websocket.h @@ -31,12 +31,13 @@ #ifndef REMOTE_DEBUGGER_PEER_WEBSOCKET_H #define REMOTE_DEBUGGER_PEER_WEBSOCKET_H +#include "core/debugger/remote_debugger_peer.h" + #ifdef JAVASCRIPT_ENABLED -#include "modules/websocket/emws_client.h" +#include "emws_client.h" #else -#include "modules/websocket/wsl_client.h" +#include "wsl_client.h" #endif -#include "core/debugger/remote_debugger_peer.h" class RemoteDebuggerPeerWebSocket : public RemoteDebuggerPeer { Ref<WebSocketClient> ws_client; @@ -50,14 +51,14 @@ public: static RemoteDebuggerPeer *create(const String &p_uri); Error connect_to_host(const String &p_uri); - bool is_peer_connected(); - int get_max_message_size() const; - bool has_message(); - Error put_message(const Array &p_arr); - Array get_message(); - void close(); - void poll(); - bool can_block() const; + bool is_peer_connected() override; + int get_max_message_size() const override; + bool has_message() override; + Error put_message(const Array &p_arr) override; + Array get_message() override; + void close() override; + void poll() override; + bool can_block() const override; RemoteDebuggerPeerWebSocket(Ref<WebSocketPeer> p_peer = Ref<WebSocketPeer>()); }; diff --git a/modules/websocket/websocket_server.cpp b/modules/websocket/websocket_server.cpp index b3f0140b80..b7851b02c4 100644 --- a/modules/websocket/websocket_server.cpp +++ b/modules/websocket/websocket_server.cpp @@ -42,6 +42,7 @@ WebSocketServer::~WebSocketServer() { void WebSocketServer::_bind_methods() { ClassDB::bind_method(D_METHOD("is_listening"), &WebSocketServer::is_listening); + ClassDB::bind_method(D_METHOD("set_extra_headers", "headers"), &WebSocketServer::set_extra_headers, DEFVAL(Vector<String>())); ClassDB::bind_method(D_METHOD("listen", "port", "protocols", "gd_mp_api"), &WebSocketServer::listen, DEFVAL(Vector<String>()), DEFVAL(false)); ClassDB::bind_method(D_METHOD("stop"), &WebSocketServer::stop); ClassDB::bind_method(D_METHOD("has_peer", "id"), &WebSocketServer::has_peer); diff --git a/modules/websocket/websocket_server.h b/modules/websocket/websocket_server.h index f6f3b80045..7bd80851f5 100644 --- a/modules/websocket/websocket_server.h +++ b/modules/websocket/websocket_server.h @@ -51,6 +51,7 @@ protected: uint32_t handshake_timeout = 3000; public: + virtual void set_extra_headers(const Vector<String> &p_headers) = 0; virtual Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) = 0; virtual void stop() = 0; virtual bool is_listening() const = 0; diff --git a/modules/websocket/wsl_client.cpp b/modules/websocket/wsl_client.cpp index 1ef571b6ee..894ba7766f 100644 --- a/modules/websocket/wsl_client.cpp +++ b/modules/websocket/wsl_client.cpp @@ -91,6 +91,7 @@ void WSLClient::_do_handshake() { data->id = 1; _peer->make_context(data, _in_buf_size, _in_pkt_size, _out_buf_size, _out_pkt_size); _peer->set_no_delay(true); + _status = CONNECTION_CONNECTED; _on_connect(protocol); break; } @@ -103,13 +104,14 @@ bool WSLClient::_verify_headers(String &r_protocol) { String s = (char *)_resp_buf; Vector<String> psa = s.split("\r\n"); int len = psa.size(); - ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers, got: " + itos(len) + ", expected >= 4."); + ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers. Got: " + itos(len) + ", expected >= 4."); Vector<String> req = psa[0].split(" ", false); - ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code."); + ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code. Got '" + psa[0] + "', expected 'HTTP/1.1 101'."); // Wrong protocol - ERR_FAIL_COND_V_MSG(req[0] != "HTTP/1.1" || req[1] != "101", false, "Invalid protocol or status code."); + ERR_FAIL_COND_V_MSG(req[0] != "HTTP/1.1", false, "Invalid protocol. Got: '" + req[0] + "', expected 'HTTP/1.1'."); + ERR_FAIL_COND_V_MSG(req[1] != "101", false, "Invalid status code. Got: '" + req[1] + "', expected '101'."); Map<String, String> headers; for (int i = 1; i < len; i++) { @@ -137,9 +139,11 @@ bool WSLClient::_verify_headers(String &r_protocol) { #undef WSL_CHECK if (_protocols.size() == 0) { // We didn't request a custom protocol - ERR_FAIL_COND_V(headers.has("sec-websocket-protocol"), false); + ERR_FAIL_COND_V_MSG(headers.has("sec-websocket-protocol"), false, "Received unrequested sub-protocol -> " + headers["sec-websocket-protocol"]); } else { - ERR_FAIL_COND_V(!headers.has("sec-websocket-protocol"), false); + // We requested at least one custom protocol but didn't receive one + ERR_FAIL_COND_V_MSG(!headers.has("sec-websocket-protocol"), false, "Requested sub-protocol(s) but received none."); + // Check received sub-protocol was one of those requested. r_protocol = headers["sec-websocket-protocol"]; bool valid = false; for (int i = 0; i < _protocols.size(); i++) { @@ -150,6 +154,7 @@ bool WSLClient::_verify_headers(String &r_protocol) { break; } if (!valid) { + ERR_FAIL_V_MSG(false, "Received unrequested sub-protocol -> " + r_protocol); return false; } } @@ -227,6 +232,7 @@ Error WSLClient::connect_to_host(String p_host, String p_path, uint16_t p_port, } request += "\r\n"; _request = request.utf8(); + _status = CONNECTION_CONNECTING; return OK; } @@ -273,6 +279,7 @@ void WSLClient::poll() { return; // Not connected. } + _tcp->poll(); switch (_tcp->get_status()) { case StreamPeerTCP::STATUS_NONE: // Clean close @@ -332,21 +339,19 @@ Ref<WebSocketPeer> WSLClient::get_peer(int p_peer_id) const { } MultiplayerPeer::ConnectionStatus WSLClient::get_connection_status() const { + // This is surprising, but keeps the current behaviour to allow clean close requests. + // TODO Refactor WebSocket and split Client/Server/Multiplayer like done in other peers. if (_peer->is_connected_to_host()) { return CONNECTION_CONNECTED; } - - if (_tcp->is_connected_to_host() || _resolver_id != IP::RESOLVER_INVALID_ID) { - return CONNECTION_CONNECTING; - } - - return CONNECTION_DISCONNECTED; + return _status; } void WSLClient::disconnect_from_host(int p_code, String p_reason) { _peer->close(p_code, p_reason); _connection = Ref<StreamPeer>(nullptr); _tcp = Ref<StreamPeerTCP>(memnew(StreamPeerTCP)); + _status = CONNECTION_DISCONNECTED; _key = ""; _host = ""; diff --git a/modules/websocket/wsl_client.h b/modules/websocket/wsl_client.h index d846e6be00..22d7ffa839 100644 --- a/modules/websocket/wsl_client.h +++ b/modules/websocket/wsl_client.h @@ -52,6 +52,7 @@ private: Ref<WSLPeer> _peer; Ref<StreamPeerTCP> _tcp; Ref<StreamPeer> _connection; + ConnectionStatus _status = CONNECTION_DISCONNECTED; CharString _request; int _requested = 0; @@ -59,11 +60,9 @@ private: uint8_t _resp_buf[WSL_MAX_HEADER_SIZE]; int _resp_pos = 0; - String _response; - String _key; String _host; - uint16_t _port; + uint16_t _port = 0; Array _ip_candidates; Vector<String> _protocols; bool _use_ssl = false; @@ -73,15 +72,15 @@ private: bool _verify_headers(String &r_protocol); public: - Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets); - Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()); - int get_max_packet_size() const; - Ref<WebSocketPeer> get_peer(int p_peer_id) const; - void disconnect_from_host(int p_code = 1000, String p_reason = ""); - IPAddress get_connected_host() const; - uint16_t get_connected_port() const; - virtual ConnectionStatus get_connection_status() const; - virtual void poll(); + Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override; + Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()) override; + int get_max_packet_size() const override; + Ref<WebSocketPeer> get_peer(int p_peer_id) const override; + void disconnect_from_host(int p_code = 1000, String p_reason = "") override; + IPAddress get_connected_host() const override; + uint16_t get_connected_port() const override; + virtual ConnectionStatus get_connection_status() const override; + virtual void poll() override; WSLClient(); ~WSLClient(); diff --git a/modules/websocket/wsl_peer.cpp b/modules/websocket/wsl_peer.cpp index d277eedace..15df4d039c 100644 --- a/modules/websocket/wsl_peer.cpp +++ b/modules/websocket/wsl_peer.cpp @@ -146,17 +146,17 @@ void wsl_msg_recv_callback(wslay_event_context_ptr ctx, const struct wslay_event if (!peer_data->valid || peer_data->closing) { return; } - WSLPeer *peer = (WSLPeer *)peer_data->peer; + WSLPeer *peer = static_cast<WSLPeer *>(peer_data->peer); if (peer->parse_message(arg) != OK) { return; } if (peer_data->is_server) { - WSLServer *helper = (WSLServer *)peer_data->obj; + WSLServer *helper = static_cast<WSLServer *>(peer_data->obj); helper->_on_peer_packet(peer_data->id); } else { - WSLClient *helper = (WSLClient *)peer_data->obj; + WSLClient *helper = static_cast<WSLClient *>(peer_data->obj); helper->_on_peer_packet(); } } @@ -184,10 +184,10 @@ Error WSLPeer::parse_message(const wslay_event_on_msg_recv_arg *arg) { } if (!wslay_event_get_close_sent(_data->ctx)) { if (_data->is_server) { - WSLServer *helper = (WSLServer *)_data->obj; + WSLServer *helper = static_cast<WSLServer *>(_data->obj); helper->_on_close_request(_data->id, close_code, close_reason); } else { - WSLClient *helper = (WSLClient *)_data->obj; + WSLClient *helper = static_cast<WSLClient *>(_data->obj); helper->_on_close_request(close_code, close_reason); } } diff --git a/modules/websocket/wsl_peer.h b/modules/websocket/wsl_peer.h index 555559c6e1..abeecdd537 100644 --- a/modules/websocket/wsl_peer.h +++ b/modules/websocket/wsl_peer.h @@ -85,22 +85,22 @@ public: String close_reason; void poll(); // Used by client and server. - virtual int get_available_packet_count() const; - virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size); - virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size); - virtual int get_max_packet_size() const { return _packet_buffer.size(); }; - virtual int get_current_outbound_buffered_amount() const; + virtual int get_available_packet_count() const override; + virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size) override; + virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size) override; + virtual int get_max_packet_size() const override { return _packet_buffer.size(); }; + virtual int get_current_outbound_buffered_amount() const override; virtual void close_now(); - virtual void close(int p_code = 1000, String p_reason = ""); - virtual bool is_connected_to_host() const; - virtual IPAddress get_connected_host() const; - virtual uint16_t get_connected_port() const; - - virtual WriteMode get_write_mode() const; - virtual void set_write_mode(WriteMode p_mode); - virtual bool was_string_packet() const; - virtual void set_no_delay(bool p_enabled); + virtual void close(int p_code = 1000, String p_reason = "") override; + virtual bool is_connected_to_host() const override; + virtual IPAddress get_connected_host() const override; + virtual uint16_t get_connected_port() const override; + + virtual WriteMode get_write_mode() const override; + virtual void set_write_mode(WriteMode p_mode) override; + virtual bool was_string_packet() const override; + virtual void set_no_delay(bool p_enabled) override; void make_context(PeerData *p_data, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size, unsigned int p_out_pkt_size); Error parse_message(const wslay_event_on_msg_recv_arg *arg); diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp index eadd7ef7ac..b58b2e4724 100644 --- a/modules/websocket/wsl_server.cpp +++ b/modules/websocket/wsl_server.cpp @@ -96,7 +96,7 @@ bool WSLServer::PendingPeer::_parse_request(const Vector<String> p_protocols, St return true; } -Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name) { +Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector<String> &p_extra_headers) { if (OS::get_singleton()->get_ticks_msec() - time > p_timeout) { print_verbose(vformat("WebSocket handshake timed out after %.3f seconds.", p_timeout * 0.001)); return ERR_TIMEOUT; @@ -141,6 +141,9 @@ Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uin if (!protocol.is_empty()) { s += "Sec-WebSocket-Protocol: " + protocol + "\r\n"; } + for (int i = 0; i < p_extra_headers.size(); i++) { + s += p_extra_headers[i] + "\r\n"; + } s += "\r\n"; response = s.utf8(); has_request = true; @@ -167,6 +170,10 @@ Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uin return OK; } +void WSLServer::set_extra_headers(const Vector<String> &p_headers) { + _extra_headers = p_headers; +} + Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp_api) { ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE); @@ -183,7 +190,7 @@ Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp void WSLServer::poll() { List<int> remove_ids; for (const KeyValue<int, Ref<WebSocketPeer>> &E : _peer_map) { - Ref<WSLPeer> peer = (WSLPeer *)E.value.ptr(); + Ref<WSLPeer> peer = const_cast<WSLPeer *>(static_cast<const WSLPeer *>(E.value.ptr())); peer->poll(); if (!peer->is_connected_to_host()) { _on_disconnect(E.key, peer->close_code != -1); @@ -199,7 +206,7 @@ void WSLServer::poll() { for (const Ref<PendingPeer> &E : _pending) { String resource_name; Ref<PendingPeer> ppeer = E; - Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name); + Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name, _extra_headers); if (err == ERR_BUSY) { continue; } else if (err != OK) { @@ -266,7 +273,7 @@ int WSLServer::get_max_packet_size() const { void WSLServer::stop() { _server->stop(); for (const KeyValue<int, Ref<WebSocketPeer>> &E : _peer_map) { - Ref<WSLPeer> peer = (WSLPeer *)E.value.ptr(); + Ref<WSLPeer> peer = const_cast<WSLPeer *>(static_cast<const WSLPeer *>(E.value.ptr())); peer->close_now(); } _pending.clear(); diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h index 221cae4793..a920e9c665 100644 --- a/modules/websocket/wsl_server.h +++ b/modules/websocket/wsl_server.h @@ -62,7 +62,7 @@ private: CharString response; int response_sent = 0; - Error do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name); + Error do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector<String> &p_extra_headers); }; int _in_buf_size = DEF_BUF_SHIFT; @@ -73,19 +73,21 @@ private: List<Ref<PendingPeer>> _pending; Ref<TCPServer> _server; Vector<String> _protocols; + Vector<String> _extra_headers; public: - Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets); - Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false); - void stop(); - bool is_listening() const; - int get_max_packet_size() const; - bool has_peer(int p_id) const; - Ref<WebSocketPeer> get_peer(int p_id) const; - IPAddress get_peer_address(int p_peer_id) const; - int get_peer_port(int p_peer_id) const; - void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = ""); - virtual void poll(); + Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override; + void set_extra_headers(const Vector<String> &p_headers) override; + Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) override; + void stop() override; + bool is_listening() const override; + int get_max_packet_size() const override; + bool has_peer(int p_id) const override; + Ref<WebSocketPeer> get_peer(int p_id) const override; + IPAddress get_peer_address(int p_peer_id) const override; + int get_peer_port(int p_peer_id) const override; + void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") override; + virtual void poll() override; WSLServer(); ~WSLServer(); |