summaryrefslogtreecommitdiff
path: root/modules/websocket
diff options
context:
space:
mode:
Diffstat (limited to 'modules/websocket')
-rw-r--r--modules/websocket/doc_classes/WebSocketServer.xml7
-rw-r--r--modules/websocket/emws_client.h18
-rw-r--r--modules/websocket/emws_peer.h26
-rw-r--r--modules/websocket/emws_server.cpp3
-rw-r--r--modules/websocket/emws_server.h23
-rw-r--r--modules/websocket/register_types.cpp34
-rw-r--r--modules/websocket/register_types.h6
-rw-r--r--modules/websocket/remote_debugger_peer_websocket.h16
-rw-r--r--modules/websocket/websocket_server.cpp1
-rw-r--r--modules/websocket/websocket_server.h1
-rw-r--r--modules/websocket/wsl_client.cpp27
-rw-r--r--modules/websocket/wsl_client.h23
-rw-r--r--modules/websocket/wsl_peer.cpp10
-rw-r--r--modules/websocket/wsl_peer.h28
-rw-r--r--modules/websocket/wsl_server.cpp15
-rw-r--r--modules/websocket/wsl_server.h26
16 files changed, 150 insertions, 114 deletions
diff --git a/modules/websocket/doc_classes/WebSocketServer.xml b/modules/websocket/doc_classes/WebSocketServer.xml
index ef3279aac4..46b0274de3 100644
--- a/modules/websocket/doc_classes/WebSocketServer.xml
+++ b/modules/websocket/doc_classes/WebSocketServer.xml
@@ -60,6 +60,13 @@
If [code]false[/code] is passed instead (default), you must call [PacketPeer] functions ([code]put_packet[/code], [code]get_packet[/code], etc.), on the [WebSocketPeer] returned via [code]get_peer(id)[/code] to communicate with the peer with given [code]id[/code] (e.g. [code]get_peer(id).get_available_packet_count[/code]).
</description>
</method>
+ <method name="set_extra_headers">
+ <return type="void" />
+ <argument index="0" name="headers" type="PackedStringArray" default="PackedStringArray()" />
+ <description>
+ Sets additional headers to be sent to clients during the HTTP handshake.
+ </description>
+ </method>
<method name="stop">
<return type="void" />
<description>
diff --git a/modules/websocket/emws_client.h b/modules/websocket/emws_client.h
index 61ea0002ea..ca327a56fa 100644
--- a/modules/websocket/emws_client.h
+++ b/modules/websocket/emws_client.h
@@ -53,15 +53,15 @@ private:
static void _esws_on_close(void *obj, int code, const char *reason, int was_clean);
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>());
- Ref<WebSocketPeer> get_peer(int p_peer_id) const;
- void disconnect_from_host(int p_code = 1000, String p_reason = "");
- IPAddress get_connected_host() const;
- uint16_t get_connected_port() const;
- virtual ConnectionStatus get_connection_status() const;
- int get_max_packet_size() const;
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()) override;
+ Ref<WebSocketPeer> get_peer(int p_peer_id) const override;
+ void disconnect_from_host(int p_code = 1000, String p_reason = "") override;
+ IPAddress get_connected_host() const override;
+ uint16_t get_connected_port() const override;
+ virtual ConnectionStatus get_connection_status() const override;
+ int get_max_packet_size() const override;
+ virtual void poll() override;
EMWSClient();
~EMWSClient();
};
diff --git a/modules/websocket/emws_peer.h b/modules/websocket/emws_peer.h
index df63d2d801..6bb4552c37 100644
--- a/modules/websocket/emws_peer.h
+++ b/modules/websocket/emws_peer.h
@@ -68,21 +68,21 @@ private:
public:
Error read_msg(const uint8_t *p_data, uint32_t p_size, bool p_is_string);
void set_sock(int p_sock, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size);
- virtual int get_available_packet_count() const;
- virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size);
- virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size);
- virtual int get_max_packet_size() const { return _packet_buffer.size(); };
- virtual int get_current_outbound_buffered_amount() const;
+ virtual int get_available_packet_count() const override;
+ virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size) override;
+ virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size) override;
+ virtual int get_max_packet_size() const override { return _packet_buffer.size(); };
+ virtual int get_current_outbound_buffered_amount() const override;
- virtual void close(int p_code = 1000, String p_reason = "");
- virtual bool is_connected_to_host() const;
- virtual IPAddress get_connected_host() const;
- virtual uint16_t get_connected_port() const;
+ virtual void close(int p_code = 1000, String p_reason = "") override;
+ virtual bool is_connected_to_host() const override;
+ virtual IPAddress get_connected_host() const override;
+ virtual uint16_t get_connected_port() const override;
- virtual WriteMode get_write_mode() const;
- virtual void set_write_mode(WriteMode p_mode);
- virtual bool was_string_packet() const;
- virtual void set_no_delay(bool p_enabled);
+ virtual WriteMode get_write_mode() const override;
+ virtual void set_write_mode(WriteMode p_mode) override;
+ virtual bool was_string_packet() const override;
+ virtual void set_no_delay(bool p_enabled) override;
EMWSPeer();
~EMWSPeer();
diff --git a/modules/websocket/emws_server.cpp b/modules/websocket/emws_server.cpp
index 53b4a0207d..2033098cad 100644
--- a/modules/websocket/emws_server.cpp
+++ b/modules/websocket/emws_server.cpp
@@ -33,6 +33,9 @@
#include "emws_server.h"
#include "core/os/os.h"
+void EMWSServer::set_extra_headers(const Vector<String> &p_headers) {
+}
+
Error EMWSServer::listen(int p_port, Vector<String> p_protocols, bool gd_mp_api) {
return FAILED;
}
diff --git a/modules/websocket/emws_server.h b/modules/websocket/emws_server.h
index f310c17c9d..ae31d9dbb0 100644
--- a/modules/websocket/emws_server.h
+++ b/modules/websocket/emws_server.h
@@ -41,17 +41,18 @@ class EMWSServer : public WebSocketServer {
GDCIIMPL(EMWSServer, WebSocketServer);
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error listen(int p_port, Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false);
- void stop();
- bool is_listening() const;
- bool has_peer(int p_id) const;
- Ref<WebSocketPeer> get_peer(int p_id) const;
- IPAddress get_peer_address(int p_peer_id) const;
- int get_peer_port(int p_peer_id) const;
- void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "");
- int get_max_packet_size() const;
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ void set_extra_headers(const Vector<String> &p_headers) override;
+ Error listen(int p_port, Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) override;
+ void stop() override;
+ bool is_listening() const override;
+ bool has_peer(int p_id) const override;
+ Ref<WebSocketPeer> get_peer(int p_id) const override;
+ IPAddress get_peer_address(int p_peer_id) const override;
+ int get_peer_port(int p_peer_id) const override;
+ void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") override;
+ int get_max_packet_size() const override;
+ virtual void poll() override;
virtual Vector<String> get_protocols() const;
EMWSServer();
diff --git a/modules/websocket/register_types.cpp b/modules/websocket/register_types.cpp
index 6d63938d4f..f562de111f 100644
--- a/modules/websocket/register_types.cpp
+++ b/modules/websocket/register_types.cpp
@@ -55,25 +55,33 @@ static void _editor_init_callback() {
}
#endif
-void register_websocket_types() {
+void initialize_websocket_module(ModuleInitializationLevel p_level) {
+ if (p_level == MODULE_INITIALIZATION_LEVEL_SCENE) {
#ifdef JAVASCRIPT_ENABLED
- EMWSPeer::make_default();
- EMWSClient::make_default();
- EMWSServer::make_default();
+ EMWSPeer::make_default();
+ EMWSClient::make_default();
+ EMWSServer::make_default();
#else
- WSLPeer::make_default();
- WSLClient::make_default();
- WSLServer::make_default();
+ WSLPeer::make_default();
+ WSLClient::make_default();
+ WSLServer::make_default();
#endif
- GDREGISTER_ABSTRACT_CLASS(WebSocketMultiplayerPeer);
- ClassDB::register_custom_instance_class<WebSocketServer>();
- ClassDB::register_custom_instance_class<WebSocketClient>();
- ClassDB::register_custom_instance_class<WebSocketPeer>();
+ GDREGISTER_ABSTRACT_CLASS(WebSocketMultiplayerPeer);
+ ClassDB::register_custom_instance_class<WebSocketServer>();
+ ClassDB::register_custom_instance_class<WebSocketClient>();
+ ClassDB::register_custom_instance_class<WebSocketPeer>();
+ }
#ifdef TOOLS_ENABLED
- EditorNode::add_init_callback(&_editor_init_callback);
+ if (p_level == MODULE_INITIALIZATION_LEVEL_EDITOR) {
+ EditorNode::add_init_callback(&_editor_init_callback);
+ }
#endif
}
-void unregister_websocket_types() {}
+void uninitialize_websocket_module(ModuleInitializationLevel p_level) {
+ if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
+ return;
+ }
+}
diff --git a/modules/websocket/register_types.h b/modules/websocket/register_types.h
index 4ab6c0cfd3..dab42d6ed9 100644
--- a/modules/websocket/register_types.h
+++ b/modules/websocket/register_types.h
@@ -31,7 +31,9 @@
#ifndef WEBSOCKET_REGISTER_TYPES_H
#define WEBSOCKET_REGISTER_TYPES_H
-void register_websocket_types();
-void unregister_websocket_types();
+#include "modules/register_module_types.h"
+
+void initialize_websocket_module(ModuleInitializationLevel p_level);
+void uninitialize_websocket_module(ModuleInitializationLevel p_level);
#endif // WEBSOCKET_REGISTER_TYPES_H
diff --git a/modules/websocket/remote_debugger_peer_websocket.h b/modules/websocket/remote_debugger_peer_websocket.h
index 84f9506625..3227065ded 100644
--- a/modules/websocket/remote_debugger_peer_websocket.h
+++ b/modules/websocket/remote_debugger_peer_websocket.h
@@ -51,14 +51,14 @@ public:
static RemoteDebuggerPeer *create(const String &p_uri);
Error connect_to_host(const String &p_uri);
- bool is_peer_connected();
- int get_max_message_size() const;
- bool has_message();
- Error put_message(const Array &p_arr);
- Array get_message();
- void close();
- void poll();
- bool can_block() const;
+ bool is_peer_connected() override;
+ int get_max_message_size() const override;
+ bool has_message() override;
+ Error put_message(const Array &p_arr) override;
+ Array get_message() override;
+ void close() override;
+ void poll() override;
+ bool can_block() const override;
RemoteDebuggerPeerWebSocket(Ref<WebSocketPeer> p_peer = Ref<WebSocketPeer>());
};
diff --git a/modules/websocket/websocket_server.cpp b/modules/websocket/websocket_server.cpp
index b3f0140b80..b7851b02c4 100644
--- a/modules/websocket/websocket_server.cpp
+++ b/modules/websocket/websocket_server.cpp
@@ -42,6 +42,7 @@ WebSocketServer::~WebSocketServer() {
void WebSocketServer::_bind_methods() {
ClassDB::bind_method(D_METHOD("is_listening"), &WebSocketServer::is_listening);
+ ClassDB::bind_method(D_METHOD("set_extra_headers", "headers"), &WebSocketServer::set_extra_headers, DEFVAL(Vector<String>()));
ClassDB::bind_method(D_METHOD("listen", "port", "protocols", "gd_mp_api"), &WebSocketServer::listen, DEFVAL(Vector<String>()), DEFVAL(false));
ClassDB::bind_method(D_METHOD("stop"), &WebSocketServer::stop);
ClassDB::bind_method(D_METHOD("has_peer", "id"), &WebSocketServer::has_peer);
diff --git a/modules/websocket/websocket_server.h b/modules/websocket/websocket_server.h
index f6f3b80045..7bd80851f5 100644
--- a/modules/websocket/websocket_server.h
+++ b/modules/websocket/websocket_server.h
@@ -51,6 +51,7 @@ protected:
uint32_t handshake_timeout = 3000;
public:
+ virtual void set_extra_headers(const Vector<String> &p_headers) = 0;
virtual Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) = 0;
virtual void stop() = 0;
virtual bool is_listening() const = 0;
diff --git a/modules/websocket/wsl_client.cpp b/modules/websocket/wsl_client.cpp
index 1ef571b6ee..894ba7766f 100644
--- a/modules/websocket/wsl_client.cpp
+++ b/modules/websocket/wsl_client.cpp
@@ -91,6 +91,7 @@ void WSLClient::_do_handshake() {
data->id = 1;
_peer->make_context(data, _in_buf_size, _in_pkt_size, _out_buf_size, _out_pkt_size);
_peer->set_no_delay(true);
+ _status = CONNECTION_CONNECTED;
_on_connect(protocol);
break;
}
@@ -103,13 +104,14 @@ bool WSLClient::_verify_headers(String &r_protocol) {
String s = (char *)_resp_buf;
Vector<String> psa = s.split("\r\n");
int len = psa.size();
- ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers, got: " + itos(len) + ", expected >= 4.");
+ ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers. Got: " + itos(len) + ", expected >= 4.");
Vector<String> req = psa[0].split(" ", false);
- ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code.");
+ ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code. Got '" + psa[0] + "', expected 'HTTP/1.1 101'.");
// Wrong protocol
- ERR_FAIL_COND_V_MSG(req[0] != "HTTP/1.1" || req[1] != "101", false, "Invalid protocol or status code.");
+ ERR_FAIL_COND_V_MSG(req[0] != "HTTP/1.1", false, "Invalid protocol. Got: '" + req[0] + "', expected 'HTTP/1.1'.");
+ ERR_FAIL_COND_V_MSG(req[1] != "101", false, "Invalid status code. Got: '" + req[1] + "', expected '101'.");
Map<String, String> headers;
for (int i = 1; i < len; i++) {
@@ -137,9 +139,11 @@ bool WSLClient::_verify_headers(String &r_protocol) {
#undef WSL_CHECK
if (_protocols.size() == 0) {
// We didn't request a custom protocol
- ERR_FAIL_COND_V(headers.has("sec-websocket-protocol"), false);
+ ERR_FAIL_COND_V_MSG(headers.has("sec-websocket-protocol"), false, "Received unrequested sub-protocol -> " + headers["sec-websocket-protocol"]);
} else {
- ERR_FAIL_COND_V(!headers.has("sec-websocket-protocol"), false);
+ // We requested at least one custom protocol but didn't receive one
+ ERR_FAIL_COND_V_MSG(!headers.has("sec-websocket-protocol"), false, "Requested sub-protocol(s) but received none.");
+ // Check received sub-protocol was one of those requested.
r_protocol = headers["sec-websocket-protocol"];
bool valid = false;
for (int i = 0; i < _protocols.size(); i++) {
@@ -150,6 +154,7 @@ bool WSLClient::_verify_headers(String &r_protocol) {
break;
}
if (!valid) {
+ ERR_FAIL_V_MSG(false, "Received unrequested sub-protocol -> " + r_protocol);
return false;
}
}
@@ -227,6 +232,7 @@ Error WSLClient::connect_to_host(String p_host, String p_path, uint16_t p_port,
}
request += "\r\n";
_request = request.utf8();
+ _status = CONNECTION_CONNECTING;
return OK;
}
@@ -273,6 +279,7 @@ void WSLClient::poll() {
return; // Not connected.
}
+ _tcp->poll();
switch (_tcp->get_status()) {
case StreamPeerTCP::STATUS_NONE:
// Clean close
@@ -332,21 +339,19 @@ Ref<WebSocketPeer> WSLClient::get_peer(int p_peer_id) const {
}
MultiplayerPeer::ConnectionStatus WSLClient::get_connection_status() const {
+ // This is surprising, but keeps the current behaviour to allow clean close requests.
+ // TODO Refactor WebSocket and split Client/Server/Multiplayer like done in other peers.
if (_peer->is_connected_to_host()) {
return CONNECTION_CONNECTED;
}
-
- if (_tcp->is_connected_to_host() || _resolver_id != IP::RESOLVER_INVALID_ID) {
- return CONNECTION_CONNECTING;
- }
-
- return CONNECTION_DISCONNECTED;
+ return _status;
}
void WSLClient::disconnect_from_host(int p_code, String p_reason) {
_peer->close(p_code, p_reason);
_connection = Ref<StreamPeer>(nullptr);
_tcp = Ref<StreamPeerTCP>(memnew(StreamPeerTCP));
+ _status = CONNECTION_DISCONNECTED;
_key = "";
_host = "";
diff --git a/modules/websocket/wsl_client.h b/modules/websocket/wsl_client.h
index d846e6be00..22d7ffa839 100644
--- a/modules/websocket/wsl_client.h
+++ b/modules/websocket/wsl_client.h
@@ -52,6 +52,7 @@ private:
Ref<WSLPeer> _peer;
Ref<StreamPeerTCP> _tcp;
Ref<StreamPeer> _connection;
+ ConnectionStatus _status = CONNECTION_DISCONNECTED;
CharString _request;
int _requested = 0;
@@ -59,11 +60,9 @@ private:
uint8_t _resp_buf[WSL_MAX_HEADER_SIZE];
int _resp_pos = 0;
- String _response;
-
String _key;
String _host;
- uint16_t _port;
+ uint16_t _port = 0;
Array _ip_candidates;
Vector<String> _protocols;
bool _use_ssl = false;
@@ -73,15 +72,15 @@ private:
bool _verify_headers(String &r_protocol);
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>());
- int get_max_packet_size() const;
- Ref<WebSocketPeer> get_peer(int p_peer_id) const;
- void disconnect_from_host(int p_code = 1000, String p_reason = "");
- IPAddress get_connected_host() const;
- uint16_t get_connected_port() const;
- virtual ConnectionStatus get_connection_status() const;
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()) override;
+ int get_max_packet_size() const override;
+ Ref<WebSocketPeer> get_peer(int p_peer_id) const override;
+ void disconnect_from_host(int p_code = 1000, String p_reason = "") override;
+ IPAddress get_connected_host() const override;
+ uint16_t get_connected_port() const override;
+ virtual ConnectionStatus get_connection_status() const override;
+ virtual void poll() override;
WSLClient();
~WSLClient();
diff --git a/modules/websocket/wsl_peer.cpp b/modules/websocket/wsl_peer.cpp
index d277eedace..15df4d039c 100644
--- a/modules/websocket/wsl_peer.cpp
+++ b/modules/websocket/wsl_peer.cpp
@@ -146,17 +146,17 @@ void wsl_msg_recv_callback(wslay_event_context_ptr ctx, const struct wslay_event
if (!peer_data->valid || peer_data->closing) {
return;
}
- WSLPeer *peer = (WSLPeer *)peer_data->peer;
+ WSLPeer *peer = static_cast<WSLPeer *>(peer_data->peer);
if (peer->parse_message(arg) != OK) {
return;
}
if (peer_data->is_server) {
- WSLServer *helper = (WSLServer *)peer_data->obj;
+ WSLServer *helper = static_cast<WSLServer *>(peer_data->obj);
helper->_on_peer_packet(peer_data->id);
} else {
- WSLClient *helper = (WSLClient *)peer_data->obj;
+ WSLClient *helper = static_cast<WSLClient *>(peer_data->obj);
helper->_on_peer_packet();
}
}
@@ -184,10 +184,10 @@ Error WSLPeer::parse_message(const wslay_event_on_msg_recv_arg *arg) {
}
if (!wslay_event_get_close_sent(_data->ctx)) {
if (_data->is_server) {
- WSLServer *helper = (WSLServer *)_data->obj;
+ WSLServer *helper = static_cast<WSLServer *>(_data->obj);
helper->_on_close_request(_data->id, close_code, close_reason);
} else {
- WSLClient *helper = (WSLClient *)_data->obj;
+ WSLClient *helper = static_cast<WSLClient *>(_data->obj);
helper->_on_close_request(close_code, close_reason);
}
}
diff --git a/modules/websocket/wsl_peer.h b/modules/websocket/wsl_peer.h
index 555559c6e1..abeecdd537 100644
--- a/modules/websocket/wsl_peer.h
+++ b/modules/websocket/wsl_peer.h
@@ -85,22 +85,22 @@ public:
String close_reason;
void poll(); // Used by client and server.
- virtual int get_available_packet_count() const;
- virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size);
- virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size);
- virtual int get_max_packet_size() const { return _packet_buffer.size(); };
- virtual int get_current_outbound_buffered_amount() const;
+ virtual int get_available_packet_count() const override;
+ virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size) override;
+ virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size) override;
+ virtual int get_max_packet_size() const override { return _packet_buffer.size(); };
+ virtual int get_current_outbound_buffered_amount() const override;
virtual void close_now();
- virtual void close(int p_code = 1000, String p_reason = "");
- virtual bool is_connected_to_host() const;
- virtual IPAddress get_connected_host() const;
- virtual uint16_t get_connected_port() const;
-
- virtual WriteMode get_write_mode() const;
- virtual void set_write_mode(WriteMode p_mode);
- virtual bool was_string_packet() const;
- virtual void set_no_delay(bool p_enabled);
+ virtual void close(int p_code = 1000, String p_reason = "") override;
+ virtual bool is_connected_to_host() const override;
+ virtual IPAddress get_connected_host() const override;
+ virtual uint16_t get_connected_port() const override;
+
+ virtual WriteMode get_write_mode() const override;
+ virtual void set_write_mode(WriteMode p_mode) override;
+ virtual bool was_string_packet() const override;
+ virtual void set_no_delay(bool p_enabled) override;
void make_context(PeerData *p_data, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size, unsigned int p_out_pkt_size);
Error parse_message(const wslay_event_on_msg_recv_arg *arg);
diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp
index eadd7ef7ac..b58b2e4724 100644
--- a/modules/websocket/wsl_server.cpp
+++ b/modules/websocket/wsl_server.cpp
@@ -96,7 +96,7 @@ bool WSLServer::PendingPeer::_parse_request(const Vector<String> p_protocols, St
return true;
}
-Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name) {
+Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector<String> &p_extra_headers) {
if (OS::get_singleton()->get_ticks_msec() - time > p_timeout) {
print_verbose(vformat("WebSocket handshake timed out after %.3f seconds.", p_timeout * 0.001));
return ERR_TIMEOUT;
@@ -141,6 +141,9 @@ Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uin
if (!protocol.is_empty()) {
s += "Sec-WebSocket-Protocol: " + protocol + "\r\n";
}
+ for (int i = 0; i < p_extra_headers.size(); i++) {
+ s += p_extra_headers[i] + "\r\n";
+ }
s += "\r\n";
response = s.utf8();
has_request = true;
@@ -167,6 +170,10 @@ Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uin
return OK;
}
+void WSLServer::set_extra_headers(const Vector<String> &p_headers) {
+ _extra_headers = p_headers;
+}
+
Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp_api) {
ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE);
@@ -183,7 +190,7 @@ Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp
void WSLServer::poll() {
List<int> remove_ids;
for (const KeyValue<int, Ref<WebSocketPeer>> &E : _peer_map) {
- Ref<WSLPeer> peer = (WSLPeer *)E.value.ptr();
+ Ref<WSLPeer> peer = const_cast<WSLPeer *>(static_cast<const WSLPeer *>(E.value.ptr()));
peer->poll();
if (!peer->is_connected_to_host()) {
_on_disconnect(E.key, peer->close_code != -1);
@@ -199,7 +206,7 @@ void WSLServer::poll() {
for (const Ref<PendingPeer> &E : _pending) {
String resource_name;
Ref<PendingPeer> ppeer = E;
- Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name);
+ Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name, _extra_headers);
if (err == ERR_BUSY) {
continue;
} else if (err != OK) {
@@ -266,7 +273,7 @@ int WSLServer::get_max_packet_size() const {
void WSLServer::stop() {
_server->stop();
for (const KeyValue<int, Ref<WebSocketPeer>> &E : _peer_map) {
- Ref<WSLPeer> peer = (WSLPeer *)E.value.ptr();
+ Ref<WSLPeer> peer = const_cast<WSLPeer *>(static_cast<const WSLPeer *>(E.value.ptr()));
peer->close_now();
}
_pending.clear();
diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h
index 221cae4793..a920e9c665 100644
--- a/modules/websocket/wsl_server.h
+++ b/modules/websocket/wsl_server.h
@@ -62,7 +62,7 @@ private:
CharString response;
int response_sent = 0;
- Error do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name);
+ Error do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector<String> &p_extra_headers);
};
int _in_buf_size = DEF_BUF_SHIFT;
@@ -73,19 +73,21 @@ private:
List<Ref<PendingPeer>> _pending;
Ref<TCPServer> _server;
Vector<String> _protocols;
+ Vector<String> _extra_headers;
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false);
- void stop();
- bool is_listening() const;
- int get_max_packet_size() const;
- bool has_peer(int p_id) const;
- Ref<WebSocketPeer> get_peer(int p_id) const;
- IPAddress get_peer_address(int p_peer_id) const;
- int get_peer_port(int p_peer_id) const;
- void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "");
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ void set_extra_headers(const Vector<String> &p_headers) override;
+ Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) override;
+ void stop() override;
+ bool is_listening() const override;
+ int get_max_packet_size() const override;
+ bool has_peer(int p_id) const override;
+ Ref<WebSocketPeer> get_peer(int p_id) const override;
+ IPAddress get_peer_address(int p_peer_id) const override;
+ int get_peer_port(int p_peer_id) const override;
+ void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") override;
+ virtual void poll() override;
WSLServer();
~WSLServer();