summaryrefslogtreecommitdiff
path: root/modules/websocket/wsl_server.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/websocket/wsl_server.cpp')
-rw-r--r--modules/websocket/wsl_server.cpp44
1 files changed, 36 insertions, 8 deletions
diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp
index efb526eed1..2181775b99 100644
--- a/modules/websocket/wsl_server.cpp
+++ b/modules/websocket/wsl_server.cpp
@@ -35,6 +35,7 @@
#include "core/project_settings.h"
WSLServer::PendingPeer::PendingPeer() {
+ use_ssl = false;
time = 0;
has_request = false;
response_sent = 0;
@@ -42,7 +43,7 @@ WSLServer::PendingPeer::PendingPeer() {
memset(req_buf, 0, sizeof(req_buf));
}
-bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) {
+bool WSLServer::PendingPeer::_parse_request(const Vector<String> p_protocols) {
Vector<String> psa = String((char *)req_buf).split("\r\n");
int len = psa.size();
ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers, got: " + itos(len) + ", expected >= 4.");
@@ -79,11 +80,12 @@ bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) {
if (headers.has("sec-websocket-protocol")) {
Vector<String> protos = headers["sec-websocket-protocol"].split(",");
for (int i = 0; i < protos.size(); i++) {
+ String proto = protos[i].strip_edges();
// Check if we have the given protocol
for (int j = 0; j < p_protocols.size(); j++) {
- if (protos[i] != p_protocols[j])
+ if (proto != p_protocols[j])
continue;
- protocol = protos[i];
+ protocol = proto;
break;
}
// Found a protocol
@@ -97,9 +99,19 @@ bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) {
return true;
}
-Error WSLServer::PendingPeer::do_handshake(PoolStringArray p_protocols) {
+Error WSLServer::PendingPeer::do_handshake(const Vector<String> p_protocols) {
if (OS::get_singleton()->get_ticks_msec() - time > WSL_SERVER_TIMEOUT)
return ERR_TIMEOUT;
+ if (use_ssl) {
+ Ref<StreamPeerSSL> ssl = static_cast<Ref<StreamPeerSSL> >(connection);
+ if (ssl.is_null())
+ return FAILED;
+ ssl->poll();
+ if (ssl->get_status() == StreamPeerSSL::STATUS_HANDSHAKING)
+ return ERR_BUSY;
+ else if (ssl->get_status() != StreamPeerSSL::STATUS_CONNECTED)
+ return FAILED;
+ }
if (!has_request) {
int read = 0;
while (true) {
@@ -143,11 +155,16 @@ Error WSLServer::PendingPeer::do_handshake(PoolStringArray p_protocols) {
return OK;
}
-Error WSLServer::listen(int p_port, PoolVector<String> p_protocols, bool gd_mp_api) {
+Error WSLServer::listen(int p_port, const Vector<String> p_protocols, bool gd_mp_api) {
ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE);
_is_multiplayer = gd_mp_api;
- _protocols = p_protocols;
+ // Strip edges from protocols.
+ _protocols.resize(p_protocols.size());
+ String *pw = _protocols.ptrw();
+ for (int i = 0; i < p_protocols.size(); i++) {
+ pw[i] = p_protocols[i].strip_edges();
+ }
_server->listen(p_port);
return OK;
@@ -185,6 +202,7 @@ void WSLServer::poll() {
WSLPeer::PeerData *data = memnew(struct WSLPeer::PeerData);
data->obj = this;
data->conn = ppeer->connection;
+ data->tcp = ppeer->tcp;
data->is_server = true;
data->id = id;
@@ -204,12 +222,21 @@ void WSLServer::poll() {
return;
while (_server->is_connection_available()) {
- Ref<StreamPeer> conn = _server->take_connection();
+ Ref<StreamPeerTCP> conn = _server->take_connection();
if (is_refusing_new_connections())
continue; // Conn will go out-of-scope and be closed.
Ref<PendingPeer> peer = memnew(PendingPeer);
- peer->connection = conn;
+ if (private_key.is_valid() && ssl_cert.is_valid()) {
+ Ref<StreamPeerSSL> ssl = Ref<StreamPeerSSL>(StreamPeerSSL::create());
+ ssl->set_blocking_handshake_enabled(false);
+ ssl->accept_stream(conn, private_key, ssl_cert, ca_chain);
+ peer->connection = ssl;
+ peer->use_ssl = true;
+ } else {
+ peer->connection = conn;
+ }
+ peer->tcp = conn;
peer->time = OS::get_singleton()->get_ticks_msec();
_pending.push_back(peer);
}
@@ -231,6 +258,7 @@ void WSLServer::stop() {
}
_pending.clear();
_peer_map.clear();
+ _protocols.clear();
}
bool WSLServer::has_peer(int p_id) const {