summaryrefslogtreecommitdiff
path: root/modules/websocket
diff options
context:
space:
mode:
Diffstat (limited to 'modules/websocket')
-rw-r--r--modules/websocket/doc_classes/WebSocketServer.xml7
-rw-r--r--modules/websocket/emws_client.h18
-rw-r--r--modules/websocket/emws_peer.h26
-rw-r--r--modules/websocket/emws_server.cpp3
-rw-r--r--modules/websocket/emws_server.h23
-rw-r--r--modules/websocket/remote_debugger_peer_websocket.h16
-rw-r--r--modules/websocket/websocket_server.cpp1
-rw-r--r--modules/websocket/websocket_server.h1
-rw-r--r--modules/websocket/wsl_client.cpp3
-rw-r--r--modules/websocket/wsl_client.h18
-rw-r--r--modules/websocket/wsl_peer.cpp10
-rw-r--r--modules/websocket/wsl_peer.h28
-rw-r--r--modules/websocket/wsl_server.cpp15
-rw-r--r--modules/websocket/wsl_server.h26
14 files changed, 109 insertions, 86 deletions
diff --git a/modules/websocket/doc_classes/WebSocketServer.xml b/modules/websocket/doc_classes/WebSocketServer.xml
index ef3279aac4..46b0274de3 100644
--- a/modules/websocket/doc_classes/WebSocketServer.xml
+++ b/modules/websocket/doc_classes/WebSocketServer.xml
@@ -60,6 +60,13 @@
If [code]false[/code] is passed instead (default), you must call [PacketPeer] functions ([code]put_packet[/code], [code]get_packet[/code], etc.), on the [WebSocketPeer] returned via [code]get_peer(id)[/code] to communicate with the peer with given [code]id[/code] (e.g. [code]get_peer(id).get_available_packet_count[/code]).
</description>
</method>
+ <method name="set_extra_headers">
+ <return type="void" />
+ <argument index="0" name="headers" type="PackedStringArray" default="PackedStringArray()" />
+ <description>
+ Sets additional headers to be sent to clients during the HTTP handshake.
+ </description>
+ </method>
<method name="stop">
<return type="void" />
<description>
diff --git a/modules/websocket/emws_client.h b/modules/websocket/emws_client.h
index 61ea0002ea..ca327a56fa 100644
--- a/modules/websocket/emws_client.h
+++ b/modules/websocket/emws_client.h
@@ -53,15 +53,15 @@ private:
static void _esws_on_close(void *obj, int code, const char *reason, int was_clean);
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>());
- Ref<WebSocketPeer> get_peer(int p_peer_id) const;
- void disconnect_from_host(int p_code = 1000, String p_reason = "");
- IPAddress get_connected_host() const;
- uint16_t get_connected_port() const;
- virtual ConnectionStatus get_connection_status() const;
- int get_max_packet_size() const;
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()) override;
+ Ref<WebSocketPeer> get_peer(int p_peer_id) const override;
+ void disconnect_from_host(int p_code = 1000, String p_reason = "") override;
+ IPAddress get_connected_host() const override;
+ uint16_t get_connected_port() const override;
+ virtual ConnectionStatus get_connection_status() const override;
+ int get_max_packet_size() const override;
+ virtual void poll() override;
EMWSClient();
~EMWSClient();
};
diff --git a/modules/websocket/emws_peer.h b/modules/websocket/emws_peer.h
index df63d2d801..6bb4552c37 100644
--- a/modules/websocket/emws_peer.h
+++ b/modules/websocket/emws_peer.h
@@ -68,21 +68,21 @@ private:
public:
Error read_msg(const uint8_t *p_data, uint32_t p_size, bool p_is_string);
void set_sock(int p_sock, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size);
- virtual int get_available_packet_count() const;
- virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size);
- virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size);
- virtual int get_max_packet_size() const { return _packet_buffer.size(); };
- virtual int get_current_outbound_buffered_amount() const;
+ virtual int get_available_packet_count() const override;
+ virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size) override;
+ virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size) override;
+ virtual int get_max_packet_size() const override { return _packet_buffer.size(); };
+ virtual int get_current_outbound_buffered_amount() const override;
- virtual void close(int p_code = 1000, String p_reason = "");
- virtual bool is_connected_to_host() const;
- virtual IPAddress get_connected_host() const;
- virtual uint16_t get_connected_port() const;
+ virtual void close(int p_code = 1000, String p_reason = "") override;
+ virtual bool is_connected_to_host() const override;
+ virtual IPAddress get_connected_host() const override;
+ virtual uint16_t get_connected_port() const override;
- virtual WriteMode get_write_mode() const;
- virtual void set_write_mode(WriteMode p_mode);
- virtual bool was_string_packet() const;
- virtual void set_no_delay(bool p_enabled);
+ virtual WriteMode get_write_mode() const override;
+ virtual void set_write_mode(WriteMode p_mode) override;
+ virtual bool was_string_packet() const override;
+ virtual void set_no_delay(bool p_enabled) override;
EMWSPeer();
~EMWSPeer();
diff --git a/modules/websocket/emws_server.cpp b/modules/websocket/emws_server.cpp
index 53b4a0207d..2033098cad 100644
--- a/modules/websocket/emws_server.cpp
+++ b/modules/websocket/emws_server.cpp
@@ -33,6 +33,9 @@
#include "emws_server.h"
#include "core/os/os.h"
+void EMWSServer::set_extra_headers(const Vector<String> &p_headers) {
+}
+
Error EMWSServer::listen(int p_port, Vector<String> p_protocols, bool gd_mp_api) {
return FAILED;
}
diff --git a/modules/websocket/emws_server.h b/modules/websocket/emws_server.h
index f310c17c9d..ae31d9dbb0 100644
--- a/modules/websocket/emws_server.h
+++ b/modules/websocket/emws_server.h
@@ -41,17 +41,18 @@ class EMWSServer : public WebSocketServer {
GDCIIMPL(EMWSServer, WebSocketServer);
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error listen(int p_port, Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false);
- void stop();
- bool is_listening() const;
- bool has_peer(int p_id) const;
- Ref<WebSocketPeer> get_peer(int p_id) const;
- IPAddress get_peer_address(int p_peer_id) const;
- int get_peer_port(int p_peer_id) const;
- void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "");
- int get_max_packet_size() const;
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ void set_extra_headers(const Vector<String> &p_headers) override;
+ Error listen(int p_port, Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) override;
+ void stop() override;
+ bool is_listening() const override;
+ bool has_peer(int p_id) const override;
+ Ref<WebSocketPeer> get_peer(int p_id) const override;
+ IPAddress get_peer_address(int p_peer_id) const override;
+ int get_peer_port(int p_peer_id) const override;
+ void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") override;
+ int get_max_packet_size() const override;
+ virtual void poll() override;
virtual Vector<String> get_protocols() const;
EMWSServer();
diff --git a/modules/websocket/remote_debugger_peer_websocket.h b/modules/websocket/remote_debugger_peer_websocket.h
index 84f9506625..3227065ded 100644
--- a/modules/websocket/remote_debugger_peer_websocket.h
+++ b/modules/websocket/remote_debugger_peer_websocket.h
@@ -51,14 +51,14 @@ public:
static RemoteDebuggerPeer *create(const String &p_uri);
Error connect_to_host(const String &p_uri);
- bool is_peer_connected();
- int get_max_message_size() const;
- bool has_message();
- Error put_message(const Array &p_arr);
- Array get_message();
- void close();
- void poll();
- bool can_block() const;
+ bool is_peer_connected() override;
+ int get_max_message_size() const override;
+ bool has_message() override;
+ Error put_message(const Array &p_arr) override;
+ Array get_message() override;
+ void close() override;
+ void poll() override;
+ bool can_block() const override;
RemoteDebuggerPeerWebSocket(Ref<WebSocketPeer> p_peer = Ref<WebSocketPeer>());
};
diff --git a/modules/websocket/websocket_server.cpp b/modules/websocket/websocket_server.cpp
index b3f0140b80..b7851b02c4 100644
--- a/modules/websocket/websocket_server.cpp
+++ b/modules/websocket/websocket_server.cpp
@@ -42,6 +42,7 @@ WebSocketServer::~WebSocketServer() {
void WebSocketServer::_bind_methods() {
ClassDB::bind_method(D_METHOD("is_listening"), &WebSocketServer::is_listening);
+ ClassDB::bind_method(D_METHOD("set_extra_headers", "headers"), &WebSocketServer::set_extra_headers, DEFVAL(Vector<String>()));
ClassDB::bind_method(D_METHOD("listen", "port", "protocols", "gd_mp_api"), &WebSocketServer::listen, DEFVAL(Vector<String>()), DEFVAL(false));
ClassDB::bind_method(D_METHOD("stop"), &WebSocketServer::stop);
ClassDB::bind_method(D_METHOD("has_peer", "id"), &WebSocketServer::has_peer);
diff --git a/modules/websocket/websocket_server.h b/modules/websocket/websocket_server.h
index f6f3b80045..7bd80851f5 100644
--- a/modules/websocket/websocket_server.h
+++ b/modules/websocket/websocket_server.h
@@ -51,6 +51,7 @@ protected:
uint32_t handshake_timeout = 3000;
public:
+ virtual void set_extra_headers(const Vector<String> &p_headers) = 0;
virtual Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) = 0;
virtual void stop() = 0;
virtual bool is_listening() const = 0;
diff --git a/modules/websocket/wsl_client.cpp b/modules/websocket/wsl_client.cpp
index 1ef571b6ee..58c329f043 100644
--- a/modules/websocket/wsl_client.cpp
+++ b/modules/websocket/wsl_client.cpp
@@ -273,6 +273,7 @@ void WSLClient::poll() {
return; // Not connected.
}
+ _tcp->poll();
switch (_tcp->get_status()) {
case StreamPeerTCP::STATUS_NONE:
// Clean close
@@ -336,7 +337,7 @@ MultiplayerPeer::ConnectionStatus WSLClient::get_connection_status() const {
return CONNECTION_CONNECTED;
}
- if (_tcp->is_connected_to_host() || _resolver_id != IP::RESOLVER_INVALID_ID) {
+ if (_tcp->get_status() == StreamPeerTCP::STATUS_CONNECTING || _resolver_id != IP::RESOLVER_INVALID_ID) {
return CONNECTION_CONNECTING;
}
diff --git a/modules/websocket/wsl_client.h b/modules/websocket/wsl_client.h
index d846e6be00..22b3a4f373 100644
--- a/modules/websocket/wsl_client.h
+++ b/modules/websocket/wsl_client.h
@@ -73,15 +73,15 @@ private:
bool _verify_headers(String &r_protocol);
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>());
- int get_max_packet_size() const;
- Ref<WebSocketPeer> get_peer(int p_peer_id) const;
- void disconnect_from_host(int p_code = 1000, String p_reason = "");
- IPAddress get_connected_host() const;
- uint16_t get_connected_port() const;
- virtual ConnectionStatus get_connection_status() const;
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocol = Vector<String>(), const Vector<String> p_custom_headers = Vector<String>()) override;
+ int get_max_packet_size() const override;
+ Ref<WebSocketPeer> get_peer(int p_peer_id) const override;
+ void disconnect_from_host(int p_code = 1000, String p_reason = "") override;
+ IPAddress get_connected_host() const override;
+ uint16_t get_connected_port() const override;
+ virtual ConnectionStatus get_connection_status() const override;
+ virtual void poll() override;
WSLClient();
~WSLClient();
diff --git a/modules/websocket/wsl_peer.cpp b/modules/websocket/wsl_peer.cpp
index d277eedace..15df4d039c 100644
--- a/modules/websocket/wsl_peer.cpp
+++ b/modules/websocket/wsl_peer.cpp
@@ -146,17 +146,17 @@ void wsl_msg_recv_callback(wslay_event_context_ptr ctx, const struct wslay_event
if (!peer_data->valid || peer_data->closing) {
return;
}
- WSLPeer *peer = (WSLPeer *)peer_data->peer;
+ WSLPeer *peer = static_cast<WSLPeer *>(peer_data->peer);
if (peer->parse_message(arg) != OK) {
return;
}
if (peer_data->is_server) {
- WSLServer *helper = (WSLServer *)peer_data->obj;
+ WSLServer *helper = static_cast<WSLServer *>(peer_data->obj);
helper->_on_peer_packet(peer_data->id);
} else {
- WSLClient *helper = (WSLClient *)peer_data->obj;
+ WSLClient *helper = static_cast<WSLClient *>(peer_data->obj);
helper->_on_peer_packet();
}
}
@@ -184,10 +184,10 @@ Error WSLPeer::parse_message(const wslay_event_on_msg_recv_arg *arg) {
}
if (!wslay_event_get_close_sent(_data->ctx)) {
if (_data->is_server) {
- WSLServer *helper = (WSLServer *)_data->obj;
+ WSLServer *helper = static_cast<WSLServer *>(_data->obj);
helper->_on_close_request(_data->id, close_code, close_reason);
} else {
- WSLClient *helper = (WSLClient *)_data->obj;
+ WSLClient *helper = static_cast<WSLClient *>(_data->obj);
helper->_on_close_request(close_code, close_reason);
}
}
diff --git a/modules/websocket/wsl_peer.h b/modules/websocket/wsl_peer.h
index 555559c6e1..abeecdd537 100644
--- a/modules/websocket/wsl_peer.h
+++ b/modules/websocket/wsl_peer.h
@@ -85,22 +85,22 @@ public:
String close_reason;
void poll(); // Used by client and server.
- virtual int get_available_packet_count() const;
- virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size);
- virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size);
- virtual int get_max_packet_size() const { return _packet_buffer.size(); };
- virtual int get_current_outbound_buffered_amount() const;
+ virtual int get_available_packet_count() const override;
+ virtual Error get_packet(const uint8_t **r_buffer, int &r_buffer_size) override;
+ virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size) override;
+ virtual int get_max_packet_size() const override { return _packet_buffer.size(); };
+ virtual int get_current_outbound_buffered_amount() const override;
virtual void close_now();
- virtual void close(int p_code = 1000, String p_reason = "");
- virtual bool is_connected_to_host() const;
- virtual IPAddress get_connected_host() const;
- virtual uint16_t get_connected_port() const;
-
- virtual WriteMode get_write_mode() const;
- virtual void set_write_mode(WriteMode p_mode);
- virtual bool was_string_packet() const;
- virtual void set_no_delay(bool p_enabled);
+ virtual void close(int p_code = 1000, String p_reason = "") override;
+ virtual bool is_connected_to_host() const override;
+ virtual IPAddress get_connected_host() const override;
+ virtual uint16_t get_connected_port() const override;
+
+ virtual WriteMode get_write_mode() const override;
+ virtual void set_write_mode(WriteMode p_mode) override;
+ virtual bool was_string_packet() const override;
+ virtual void set_no_delay(bool p_enabled) override;
void make_context(PeerData *p_data, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size, unsigned int p_out_pkt_size);
Error parse_message(const wslay_event_on_msg_recv_arg *arg);
diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp
index eadd7ef7ac..b58b2e4724 100644
--- a/modules/websocket/wsl_server.cpp
+++ b/modules/websocket/wsl_server.cpp
@@ -96,7 +96,7 @@ bool WSLServer::PendingPeer::_parse_request(const Vector<String> p_protocols, St
return true;
}
-Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name) {
+Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector<String> &p_extra_headers) {
if (OS::get_singleton()->get_ticks_msec() - time > p_timeout) {
print_verbose(vformat("WebSocket handshake timed out after %.3f seconds.", p_timeout * 0.001));
return ERR_TIMEOUT;
@@ -141,6 +141,9 @@ Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uin
if (!protocol.is_empty()) {
s += "Sec-WebSocket-Protocol: " + protocol + "\r\n";
}
+ for (int i = 0; i < p_extra_headers.size(); i++) {
+ s += p_extra_headers[i] + "\r\n";
+ }
s += "\r\n";
response = s.utf8();
has_request = true;
@@ -167,6 +170,10 @@ Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols, uin
return OK;
}
+void WSLServer::set_extra_headers(const Vector<String> &p_headers) {
+ _extra_headers = p_headers;
+}
+
Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp_api) {
ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE);
@@ -183,7 +190,7 @@ Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp
void WSLServer::poll() {
List<int> remove_ids;
for (const KeyValue<int, Ref<WebSocketPeer>> &E : _peer_map) {
- Ref<WSLPeer> peer = (WSLPeer *)E.value.ptr();
+ Ref<WSLPeer> peer = const_cast<WSLPeer *>(static_cast<const WSLPeer *>(E.value.ptr()));
peer->poll();
if (!peer->is_connected_to_host()) {
_on_disconnect(E.key, peer->close_code != -1);
@@ -199,7 +206,7 @@ void WSLServer::poll() {
for (const Ref<PendingPeer> &E : _pending) {
String resource_name;
Ref<PendingPeer> ppeer = E;
- Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name);
+ Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name, _extra_headers);
if (err == ERR_BUSY) {
continue;
} else if (err != OK) {
@@ -266,7 +273,7 @@ int WSLServer::get_max_packet_size() const {
void WSLServer::stop() {
_server->stop();
for (const KeyValue<int, Ref<WebSocketPeer>> &E : _peer_map) {
- Ref<WSLPeer> peer = (WSLPeer *)E.value.ptr();
+ Ref<WSLPeer> peer = const_cast<WSLPeer *>(static_cast<const WSLPeer *>(E.value.ptr()));
peer->close_now();
}
_pending.clear();
diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h
index 221cae4793..a920e9c665 100644
--- a/modules/websocket/wsl_server.h
+++ b/modules/websocket/wsl_server.h
@@ -62,7 +62,7 @@ private:
CharString response;
int response_sent = 0;
- Error do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name);
+ Error do_handshake(const Vector<String> p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector<String> &p_extra_headers);
};
int _in_buf_size = DEF_BUF_SHIFT;
@@ -73,19 +73,21 @@ private:
List<Ref<PendingPeer>> _pending;
Ref<TCPServer> _server;
Vector<String> _protocols;
+ Vector<String> _extra_headers;
public:
- Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets);
- Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false);
- void stop();
- bool is_listening() const;
- int get_max_packet_size() const;
- bool has_peer(int p_id) const;
- Ref<WebSocketPeer> get_peer(int p_id) const;
- IPAddress get_peer_address(int p_peer_id) const;
- int get_peer_port(int p_peer_id) const;
- void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "");
- virtual void poll();
+ Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override;
+ void set_extra_headers(const Vector<String> &p_headers) override;
+ Error listen(int p_port, const Vector<String> p_protocols = Vector<String>(), bool gd_mp_api = false) override;
+ void stop() override;
+ bool is_listening() const override;
+ int get_max_packet_size() const override;
+ bool has_peer(int p_id) const override;
+ Ref<WebSocketPeer> get_peer(int p_id) const override;
+ IPAddress get_peer_address(int p_peer_id) const override;
+ int get_peer_port(int p_peer_id) const override;
+ void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") override;
+ virtual void poll() override;
WSLServer();
~WSLServer();