summaryrefslogtreecommitdiff
path: root/modules/websocket/wsl_peer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/websocket/wsl_peer.cpp')
-rw-r--r--modules/websocket/wsl_peer.cpp849
1 files changed, 672 insertions, 177 deletions
diff --git a/modules/websocket/wsl_peer.cpp b/modules/websocket/wsl_peer.cpp
index 97bd87a526..84e022182e 100644
--- a/modules/websocket/wsl_peer.cpp
+++ b/modules/websocket/wsl_peer.cpp
@@ -32,71 +32,537 @@
#include "wsl_peer.h"
-#include "wsl_client.h"
-#include "wsl_server.h"
+#include "wsl_peer.h"
-#include "core/crypto/crypto_core.h"
-#include "core/math/random_number_generator.h"
-#include "core/os/os.h"
+#include "core/io/stream_peer_tls.h"
-String WSLPeer::generate_key() {
- // Random key
- RandomNumberGenerator rng;
- rng.set_seed(OS::get_singleton()->get_unix_time());
- Vector<uint8_t> bkey;
- int len = 16; // 16 bytes, as per RFC
- bkey.resize(len);
- uint8_t *w = bkey.ptrw();
- for (int i = 0; i < len; i++) {
- w[i] = (uint8_t)rng.randi_range(0, 255);
+CryptoCore::RandomGenerator *WSLPeer::_static_rng = nullptr;
+
+void WSLPeer::initialize() {
+ WebSocketPeer::_create = WSLPeer::_create;
+ _static_rng = memnew(CryptoCore::RandomGenerator);
+ _static_rng->init();
+}
+
+void WSLPeer::deinitialize() {
+ if (_static_rng) {
+ memdelete(_static_rng);
+ _static_rng = nullptr;
}
- return CryptoCore::b64_encode_str(&w[0], len);
}
-String WSLPeer::compute_key_response(String p_key) {
- String key = p_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // Magic UUID as per RFC
- Vector<uint8_t> sha = key.sha1_buffer();
- return CryptoCore::b64_encode_str(sha.ptr(), sha.size());
+///
+/// Resolver
+///
+void WSLPeer::Resolver::start(const String &p_host, int p_port) {
+ stop();
+
+ port = p_port;
+ if (p_host.is_valid_ip_address()) {
+ ip_candidates.push_back(IPAddress(p_host));
+ } else {
+ // Queue hostname for resolution.
+ resolver_id = IP::get_singleton()->resolve_hostname_queue_item(p_host);
+ ERR_FAIL_COND(resolver_id == IP::RESOLVER_INVALID_ID);
+ // Check if it was found in cache.
+ IP::ResolverStatus ip_status = IP::get_singleton()->get_resolve_item_status(resolver_id);
+ if (ip_status == IP::RESOLVER_STATUS_DONE) {
+ ip_candidates = IP::get_singleton()->get_resolve_item_addresses(resolver_id);
+ IP::get_singleton()->erase_resolve_item(resolver_id);
+ resolver_id = IP::RESOLVER_INVALID_ID;
+ }
+ }
}
-void WSLPeer::_wsl_destroy(struct PeerData **p_data) {
- if (!p_data || !(*p_data)) {
- return;
+void WSLPeer::Resolver::stop() {
+ if (resolver_id != IP::RESOLVER_INVALID_ID) {
+ IP::get_singleton()->erase_resolve_item(resolver_id);
+ resolver_id = IP::RESOLVER_INVALID_ID;
+ }
+ port = 0;
+}
+
+void WSLPeer::Resolver::try_next_candidate(Ref<StreamPeerTCP> &p_tcp) {
+ // Check if we still need resolving.
+ if (resolver_id != IP::RESOLVER_INVALID_ID) {
+ IP::ResolverStatus ip_status = IP::get_singleton()->get_resolve_item_status(resolver_id);
+ if (ip_status == IP::RESOLVER_STATUS_WAITING) {
+ return;
+ }
+ if (ip_status == IP::RESOLVER_STATUS_DONE) {
+ ip_candidates = IP::get_singleton()->get_resolve_item_addresses(resolver_id);
+ }
+ IP::get_singleton()->erase_resolve_item(resolver_id);
+ resolver_id = IP::RESOLVER_INVALID_ID;
}
- struct PeerData *data = *p_data;
- if (data->polling) {
- data->destroy = true;
+
+ // Try the current candidate if we have one.
+ if (p_tcp->get_status() != StreamPeerTCP::STATUS_NONE) {
+ p_tcp->poll();
+ StreamPeerTCP::Status status = p_tcp->get_status();
+ if (status == StreamPeerTCP::STATUS_CONNECTED) {
+ p_tcp->set_no_delay(true);
+ ip_candidates.clear();
+ return;
+ } else if (status == StreamPeerTCP::STATUS_CONNECTING) {
+ return; // Keep connecting.
+ } else {
+ p_tcp->disconnect_from_host();
+ }
+ }
+
+ // Keep trying next candidate.
+ while (ip_candidates.size()) {
+ Error err = p_tcp->connect_to_host(ip_candidates.pop_front(), port);
+ if (err == OK) {
+ return;
+ } else {
+ p_tcp->disconnect_from_host();
+ }
+ }
+}
+
+///
+/// Server functions
+///
+Error WSLPeer::accept_stream(Ref<StreamPeer> p_stream) {
+ ERR_FAIL_COND_V(wsl_ctx || tcp.is_valid(), ERR_ALREADY_IN_USE);
+ ERR_FAIL_COND_V(p_stream.is_null(), ERR_INVALID_PARAMETER);
+
+ _clear();
+
+ if (p_stream->is_class_ptr(StreamPeerTCP::get_class_ptr_static())) {
+ tcp = p_stream;
+ connection = p_stream;
+ use_tls = false;
+ } else if (p_stream->is_class_ptr(StreamPeerTLS::get_class_ptr_static())) {
+ Ref<StreamPeer> base_stream = static_cast<Ref<StreamPeerTLS>>(p_stream)->get_stream();
+ ERR_FAIL_COND_V(base_stream.is_null() || !base_stream->is_class_ptr(StreamPeerTCP::get_class_ptr_static()), ERR_INVALID_PARAMETER);
+ tcp = static_cast<Ref<StreamPeerTCP>>(base_stream);
+ connection = p_stream;
+ use_tls = true;
+ }
+ ERR_FAIL_COND_V(connection.is_null() || tcp.is_null(), ERR_INVALID_PARAMETER);
+ is_server = true;
+ ready_state = STATE_CONNECTING;
+ handshake_buffer->resize(WSL_MAX_HEADER_SIZE);
+ handshake_buffer->seek(0);
+ return OK;
+}
+
+bool WSLPeer::_parse_client_request() {
+ Vector<String> psa = String((const char *)handshake_buffer->get_data_array().ptr(), handshake_buffer->get_position() - 4).split("\r\n");
+ int len = psa.size();
+ ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers, got: " + itos(len) + ", expected >= 4.");
+
+ Vector<String> req = psa[0].split(" ", false);
+ ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code.");
+
+ // Wrong protocol
+ ERR_FAIL_COND_V_MSG(req[0] != "GET" || req[2] != "HTTP/1.1", false, "Invalid method or HTTP version.");
+
+ HashMap<String, String> headers;
+ for (int i = 1; i < len; i++) {
+ Vector<String> header = psa[i].split(":", false, 1);
+ ERR_FAIL_COND_V_MSG(header.size() != 2, false, "Invalid header -> " + psa[i]);
+ String name = header[0].to_lower();
+ String value = header[1].strip_edges();
+ if (headers.has(name)) {
+ headers[name] += "," + value;
+ } else {
+ headers[name] = value;
+ }
+ }
+ requested_host = headers.has("host") ? headers.get("host") : "";
+ requested_url = (use_tls ? "wss://" : "ws://") + requested_host + req[1];
+#define WSL_CHECK(NAME, VALUE) \
+ ERR_FAIL_COND_V_MSG(!headers.has(NAME) || headers[NAME].to_lower() != VALUE, false, \
+ "Missing or invalid header '" + String(NAME) + "'. Expected value '" + VALUE + "'.");
+#define WSL_CHECK_EX(NAME) \
+ ERR_FAIL_COND_V_MSG(!headers.has(NAME), false, "Missing header '" + String(NAME) + "'.");
+ WSL_CHECK("upgrade", "websocket");
+ WSL_CHECK("sec-websocket-version", "13");
+ WSL_CHECK_EX("sec-websocket-key");
+ WSL_CHECK_EX("connection");
+#undef WSL_CHECK_EX
+#undef WSL_CHECK
+ session_key = headers["sec-websocket-key"];
+ if (headers.has("sec-websocket-protocol")) {
+ Vector<String> protos = headers["sec-websocket-protocol"].split(",");
+ for (int i = 0; i < protos.size(); i++) {
+ String proto = protos[i].strip_edges();
+ // Check if we have the given protocol
+ for (int j = 0; j < supported_protocols.size(); j++) {
+ if (proto != supported_protocols[j]) {
+ continue;
+ }
+ selected_protocol = proto;
+ break;
+ }
+ // Found a protocol
+ if (!selected_protocol.is_empty()) {
+ break;
+ }
+ }
+ if (selected_protocol.is_empty()) { // Invalid protocol(s) requested
+ return false;
+ }
+ } else if (supported_protocols.size() > 0) { // No protocol requested, but we need one
+ return false;
+ }
+ return true;
+}
+
+Error WSLPeer::_do_server_handshake() {
+ if (use_tls) {
+ Ref<StreamPeerTLS> tls = static_cast<Ref<StreamPeerTLS>>(connection);
+ if (tls.is_null()) {
+ ERR_FAIL_V_MSG(ERR_BUG, "Couldn't get StreamPeerTLS for WebSocket handshake.");
+ close(-1);
+ return FAILED;
+ }
+ tls->poll();
+ if (tls->get_status() == StreamPeerTLS::STATUS_HANDSHAKING) {
+ return OK; // Pending handshake
+ } else if (tls->get_status() != StreamPeerTLS::STATUS_CONNECTED) {
+ print_verbose(vformat("WebSocket SSL connection error during handshake (StreamPeerTLS status code %d).", tls->get_status()));
+ close(-1);
+ return FAILED;
+ }
+ }
+
+ if (pending_request) {
+ int read = 0;
+ while (true) {
+ ERR_FAIL_COND_V_MSG(handshake_buffer->get_available_bytes() < 1, ERR_OUT_OF_MEMORY, "WebSocket response headers are too big.");
+ int pos = handshake_buffer->get_position();
+ uint8_t byte;
+ Error err = connection->get_partial_data(&byte, 1, read);
+ if (err != OK) { // Got an error
+ print_verbose(vformat("WebSocket error while getting partial data (StreamPeer error code %d).", err));
+ close(-1);
+ return FAILED;
+ } else if (read != 1) { // Busy, wait next poll
+ return OK;
+ }
+ handshake_buffer->put_u8(byte);
+ const char *r = (const char *)handshake_buffer->get_data_array().ptr();
+ int l = pos;
+ if (l > 3 && r[l] == '\n' && r[l - 1] == '\r' && r[l - 2] == '\n' && r[l - 3] == '\r') {
+ if (!_parse_client_request()) {
+ close(-1);
+ return FAILED;
+ }
+ String s = "HTTP/1.1 101 Switching Protocols\r\n";
+ s += "Upgrade: websocket\r\n";
+ s += "Connection: Upgrade\r\n";
+ s += "Sec-WebSocket-Accept: " + _compute_key_response(session_key) + "\r\n";
+ if (!selected_protocol.is_empty()) {
+ s += "Sec-WebSocket-Protocol: " + selected_protocol + "\r\n";
+ }
+ for (int i = 0; i < handshake_headers.size(); i++) {
+ s += handshake_headers[i] + "\r\n";
+ }
+ s += "\r\n";
+ CharString cs = s.utf8();
+ handshake_buffer->clear();
+ handshake_buffer->put_data((const uint8_t *)cs.get_data(), cs.length());
+ handshake_buffer->seek(0);
+ pending_request = false;
+ break;
+ }
+ }
+ }
+
+ if (pending_request) { // Still pending.
+ return OK;
+ }
+
+ int left = handshake_buffer->get_available_bytes();
+ if (left) {
+ Vector<uint8_t> data = handshake_buffer->get_data_array();
+ int pos = handshake_buffer->get_position();
+ int sent = 0;
+ Error err = connection->put_partial_data(data.ptr() + pos, left, sent);
+ if (err != OK) {
+ print_verbose(vformat("WebSocket error while putting partial data (StreamPeer error code %d).", err));
+ close(-1);
+ return err;
+ }
+ handshake_buffer->seek(pos + sent);
+ left -= sent;
+ if (left == 0) {
+ resolver.stop();
+ // Response sent, initialize wslay context.
+ wslay_event_context_server_init(&wsl_ctx, &_wsl_callbacks, this);
+ wslay_event_config_set_max_recv_msg_length(wsl_ctx, inbound_buffer_size);
+ in_buffer.resize(nearest_shift(inbound_buffer_size), max_queued_packets);
+ packet_buffer.resize(inbound_buffer_size);
+ ready_state = STATE_OPEN;
+ }
+ }
+
+ return OK;
+}
+
+///
+/// Client functions
+///
+void WSLPeer::_do_client_handshake() {
+ ERR_FAIL_COND(tcp.is_null());
+
+ // Try to connect to candidates.
+ if (resolver.has_more_candidates()) {
+ resolver.try_next_candidate(tcp);
+ if (resolver.has_more_candidates()) {
+ return; // Still pending.
+ }
+ }
+
+ tcp->poll();
+ if (tcp->get_status() == StreamPeerTCP::STATUS_CONNECTING) {
+ return; // Keep connecting.
+ } else if (tcp->get_status() != StreamPeerTCP::STATUS_CONNECTED) {
+ close(-1); // Failed to connect.
return;
}
- wslay_event_context_free(data->ctx);
- memdelete(data);
- *p_data = nullptr;
+
+ if (use_tls) {
+ Ref<StreamPeerTLS> tls;
+ if (connection == tcp) {
+ // Start SSL handshake
+ tls = Ref<StreamPeerTLS>(StreamPeerTLS::create());
+ ERR_FAIL_COND_MSG(tls.is_null(), "SSL is not available in this build.");
+ tls->set_blocking_handshake_enabled(false);
+ if (tls->connect_to_stream(tcp, verify_tls, requested_host, tls_cert) != OK) {
+ close(-1);
+ return; // Error.
+ }
+ connection = tls;
+ } else {
+ tls = static_cast<Ref<StreamPeerTLS>>(connection);
+ ERR_FAIL_COND(tls.is_null());
+ tls->poll();
+ }
+ if (tls->get_status() == StreamPeerTLS::STATUS_HANDSHAKING) {
+ return; // Need more polling.
+ } else if (tls->get_status() != StreamPeerTLS::STATUS_CONNECTED) {
+ close(-1);
+ return; // Error.
+ }
+ }
+
+ // Do websocket handshake.
+ if (pending_request) {
+ int left = handshake_buffer->get_available_bytes();
+ int pos = handshake_buffer->get_position();
+ const Vector<uint8_t> data = handshake_buffer->get_data_array();
+ int sent = 0;
+ Error err = connection->put_partial_data(data.ptr() + pos, left, sent);
+ // Sending handshake failed
+ if (err != OK) {
+ close(-1);
+ return; // Error.
+ }
+ handshake_buffer->seek(pos + sent);
+ if (handshake_buffer->get_available_bytes() == 0) {
+ pending_request = false;
+ handshake_buffer->clear();
+ handshake_buffer->resize(WSL_MAX_HEADER_SIZE);
+ handshake_buffer->seek(0);
+ }
+ } else {
+ int read = 0;
+ while (true) {
+ int left = handshake_buffer->get_available_bytes();
+ int pos = handshake_buffer->get_position();
+ if (left == 0) {
+ // Header is too big
+ close(-1);
+ ERR_FAIL_MSG("Response headers too big.");
+ return;
+ }
+
+ uint8_t byte;
+ Error err = connection->get_partial_data(&byte, 1, read);
+ if (err != OK) {
+ // Got some error.
+ close(-1);
+ return;
+ } else if (read != 1) {
+ // Busy, wait next poll.
+ break;
+ }
+ handshake_buffer->put_u8(byte);
+
+ // Check "\r\n\r\n" header terminator
+ const char *r = (const char *)handshake_buffer->get_data_array().ptr();
+ int l = pos;
+ if (l > 3 && r[l] == '\n' && r[l - 1] == '\r' && r[l - 2] == '\n' && r[l - 3] == '\r') {
+ // Response is over, verify headers and initialize wslay context/
+ if (!_verify_server_response()) {
+ close(-1);
+ ERR_FAIL_MSG("Invalid response headers.");
+ return;
+ }
+ wslay_event_context_client_init(&wsl_ctx, &_wsl_callbacks, this);
+ wslay_event_config_set_max_recv_msg_length(wsl_ctx, inbound_buffer_size);
+ in_buffer.resize(nearest_shift(inbound_buffer_size), max_queued_packets);
+ packet_buffer.resize(inbound_buffer_size);
+ ready_state = STATE_OPEN;
+ break;
+ }
+ }
+ }
}
-bool WSLPeer::_wsl_poll(struct PeerData *p_data) {
- p_data->polling = true;
- int err = 0;
- if ((err = wslay_event_recv(p_data->ctx)) != 0 || (err = wslay_event_send(p_data->ctx)) != 0) {
- print_verbose("Websocket (wslay) poll error: " + itos(err));
- p_data->destroy = true;
+bool WSLPeer::_verify_server_response() {
+ Vector<String> psa = String((const char *)handshake_buffer->get_data_array().ptr(), handshake_buffer->get_position() - 4).split("\r\n");
+ int len = psa.size();
+ ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers. Got: " + itos(len) + ", expected >= 4.");
+
+ Vector<String> req = psa[0].split(" ", false);
+ ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code. Got '" + psa[0] + "', expected 'HTTP/1.1 101'.");
+
+ // Wrong protocol
+ ERR_FAIL_COND_V_MSG(req[0] != "HTTP/1.1", false, "Invalid protocol. Got: '" + req[0] + "', expected 'HTTP/1.1'.");
+ ERR_FAIL_COND_V_MSG(req[1] != "101", false, "Invalid status code. Got: '" + req[1] + "', expected '101'.");
+
+ HashMap<String, String> headers;
+ for (int i = 1; i < len; i++) {
+ Vector<String> header = psa[i].split(":", false, 1);
+ ERR_FAIL_COND_V_MSG(header.size() != 2, false, "Invalid header -> " + psa[i] + ".");
+ String name = header[0].to_lower();
+ String value = header[1].strip_edges();
+ if (headers.has(name)) {
+ headers[name] += "," + value;
+ } else {
+ headers[name] = value;
+ }
}
- p_data->polling = false;
- if (p_data->destroy || (wslay_event_get_close_sent(p_data->ctx) && wslay_event_get_close_received(p_data->ctx))) {
- bool valid = p_data->valid;
- _wsl_destroy(&p_data);
- return valid;
+#define WSL_CHECK(NAME, VALUE) \
+ ERR_FAIL_COND_V_MSG(!headers.has(NAME) || headers[NAME].to_lower() != VALUE, false, \
+ "Missing or invalid header '" + String(NAME) + "'. Expected value '" + VALUE + "'.");
+#define WSL_CHECK_NC(NAME, VALUE) \
+ ERR_FAIL_COND_V_MSG(!headers.has(NAME) || headers[NAME] != VALUE, false, \
+ "Missing or invalid header '" + String(NAME) + "'. Expected value '" + VALUE + "'.");
+ WSL_CHECK("connection", "upgrade");
+ WSL_CHECK("upgrade", "websocket");
+ WSL_CHECK_NC("sec-websocket-accept", _compute_key_response(session_key));
+#undef WSL_CHECK_NC
+#undef WSL_CHECK
+ if (supported_protocols.size() == 0) {
+ // We didn't request a custom protocol
+ ERR_FAIL_COND_V_MSG(headers.has("sec-websocket-protocol"), false, "Received unrequested sub-protocol -> " + headers["sec-websocket-protocol"]);
+ } else {
+ // We requested at least one custom protocol but didn't receive one
+ ERR_FAIL_COND_V_MSG(!headers.has("sec-websocket-protocol"), false, "Requested sub-protocol(s) but received none.");
+ // Check received sub-protocol was one of those requested.
+ selected_protocol = headers["sec-websocket-protocol"];
+ bool valid = false;
+ for (int i = 0; i < supported_protocols.size(); i++) {
+ if (supported_protocols[i] != selected_protocol) {
+ continue;
+ }
+ valid = true;
+ break;
+ }
+ if (!valid) {
+ ERR_FAIL_V_MSG(false, "Received unrequested sub-protocol -> " + selected_protocol);
+ return false;
+ }
+ }
+ return true;
+}
+
+Error WSLPeer::connect_to_url(const String &p_url, bool p_verify_tls, Ref<X509Certificate> p_cert) {
+ ERR_FAIL_COND_V(wsl_ctx || tcp.is_valid(), ERR_ALREADY_IN_USE);
+ ERR_FAIL_COND_V(p_url.is_empty(), ERR_INVALID_PARAMETER);
+
+ _clear();
+
+ String host;
+ String path;
+ String scheme;
+ int port = 0;
+ Error err = p_url.parse_url(scheme, host, port, path);
+ ERR_FAIL_COND_V_MSG(err != OK, err, "Invalid URL: " + p_url);
+ if (scheme.is_empty()) {
+ scheme = "ws://";
+ }
+ ERR_FAIL_COND_V_MSG(scheme != "ws://" && scheme != "wss://", ERR_INVALID_PARAMETER, vformat("Invalid protocol: \"%s\" (must be either \"ws://\" or \"wss://\").", scheme));
+
+ use_tls = false;
+ if (scheme == "wss://") {
+ use_tls = true;
+ }
+ if (port == 0) {
+ port = use_tls ? 443 : 80;
+ }
+ if (path.is_empty()) {
+ path = "/";
+ }
+
+ requested_url = p_url;
+ requested_host = host;
+ verify_tls = p_verify_tls;
+ tls_cert = p_cert;
+ tcp.instantiate();
+
+ resolver.start(host, port);
+ resolver.try_next_candidate(tcp);
+
+ if (tcp->get_status() != StreamPeerTCP::STATUS_CONNECTING && tcp->get_status() != StreamPeerTCP::STATUS_CONNECTED && !resolver.has_more_candidates()) {
+ _clear();
+ return FAILED;
+ }
+ connection = tcp;
+
+ // Prepare handshake request.
+ session_key = _generate_key();
+ String request = "GET " + path + " HTTP/1.1\r\n";
+ String port_string;
+ if ((port != 80 && !use_tls) || (port != 443 && use_tls)) {
+ port_string = ":" + itos(port);
+ }
+ request += "Host: " + host + port_string + "\r\n";
+ request += "Upgrade: websocket\r\n";
+ request += "Connection: Upgrade\r\n";
+ request += "Sec-WebSocket-Key: " + session_key + "\r\n";
+ request += "Sec-WebSocket-Version: 13\r\n";
+ if (supported_protocols.size() > 0) {
+ request += "Sec-WebSocket-Protocol: ";
+ for (int i = 0; i < supported_protocols.size(); i++) {
+ if (i != 0) {
+ request += ",";
+ }
+ request += supported_protocols[i];
+ }
+ request += "\r\n";
}
- return false;
+ for (int i = 0; i < handshake_headers.size(); i++) {
+ request += handshake_headers[i] + "\r\n";
+ }
+ request += "\r\n";
+ CharString cs = request.utf8();
+ handshake_buffer->put_data((const uint8_t *)cs.get_data(), cs.length());
+ handshake_buffer->seek(0);
+ ready_state = STATE_CONNECTING;
+ is_server = false;
+ return OK;
}
-ssize_t wsl_recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len, int flags, void *user_data) {
- struct WSLPeer::PeerData *peer_data = (struct WSLPeer::PeerData *)user_data;
- if (!peer_data->valid) {
+///
+/// Callback functions.
+///
+ssize_t WSLPeer::_wsl_recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len, int flags, void *user_data) {
+ WSLPeer *peer = (WSLPeer *)user_data;
+ Ref<StreamPeer> conn = peer->connection;
+ if (conn.is_null()) {
wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
return -1;
}
- Ref<StreamPeer> conn = peer_data->conn;
int read = 0;
Error err = conn->get_partial_data(data, len, read);
if (err != OK) {
@@ -111,13 +577,13 @@ ssize_t wsl_recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len
return read;
}
-ssize_t wsl_send_callback(wslay_event_context_ptr ctx, const uint8_t *data, size_t len, int flags, void *user_data) {
- struct WSLPeer::PeerData *peer_data = (struct WSLPeer::PeerData *)user_data;
- if (!peer_data->valid) {
+ssize_t WSLPeer::_wsl_send_callback(wslay_event_context_ptr ctx, const uint8_t *data, size_t len, int flags, void *user_data) {
+ WSLPeer *peer = (WSLPeer *)user_data;
+ Ref<StreamPeer> conn = peer->connection;
+ if (conn.is_null()) {
wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
return -1;
}
- Ref<StreamPeer> conn = peer_data->conn;
int sent = 0;
Error err = conn->put_partial_data(data, len, sent);
if (err != OK) {
@@ -131,144 +597,142 @@ ssize_t wsl_send_callback(wslay_event_context_ptr ctx, const uint8_t *data, size
return sent;
}
-int wsl_genmask_callback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, void *user_data) {
- RandomNumberGenerator rng;
- // TODO maybe use crypto in the future?
- rng.set_seed(OS::get_singleton()->get_unix_time());
- for (unsigned int i = 0; i < len; i++) {
- buf[i] = (uint8_t)rng.randi_range(0, 255);
- }
+int WSLPeer::_wsl_genmask_callback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, void *user_data) {
+ ERR_FAIL_COND_V(!_static_rng, WSLAY_ERR_CALLBACK_FAILURE);
+ Error err = _static_rng->get_random_bytes(buf, len);
+ ERR_FAIL_COND_V(err != OK, WSLAY_ERR_CALLBACK_FAILURE);
return 0;
}
-void wsl_msg_recv_callback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data) {
- struct WSLPeer::PeerData *peer_data = (struct WSLPeer::PeerData *)user_data;
- if (!peer_data->valid || peer_data->closing) {
+void WSLPeer::_wsl_msg_recv_callback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data) {
+ WSLPeer *peer = (WSLPeer *)user_data;
+ uint8_t op = arg->opcode;
+
+ if (op == WSLAY_CONNECTION_CLOSE) {
+ // Close request or confirmation.
+ peer->close_code = arg->status_code;
+ size_t len = arg->msg_length;
+ peer->close_reason = "";
+ if (len > 2 /* first 2 bytes = close code */) {
+ peer->close_reason.parse_utf8((char *)arg->msg + 2, len - 2);
+ }
+ if (peer->ready_state == STATE_OPEN) {
+ peer->ready_state = STATE_CLOSING;
+ }
return;
}
- WSLPeer *peer = static_cast<WSLPeer *>(peer_data->peer);
- if (peer->parse_message(arg) != OK) {
+ if (peer->ready_state == STATE_CLOSING) {
return;
}
- if (peer_data->is_server) {
- WSLServer *helper = static_cast<WSLServer *>(peer_data->obj);
- helper->_on_peer_packet(peer_data->id);
- } else {
- WSLClient *helper = static_cast<WSLClient *>(peer_data->obj);
- helper->_on_peer_packet();
+ if (op == WSLAY_TEXT_FRAME || op == WSLAY_BINARY_FRAME) {
+ // Message.
+ uint8_t is_string = arg->opcode == WSLAY_TEXT_FRAME ? 1 : 0;
+ peer->in_buffer.write_packet(arg->msg, arg->msg_length, &is_string);
}
+ // Ping or pong.
}
-wslay_event_callbacks wsl_callbacks = {
- wsl_recv_callback,
- wsl_send_callback,
- wsl_genmask_callback,
+wslay_event_callbacks WSLPeer::_wsl_callbacks = {
+ _wsl_recv_callback,
+ _wsl_send_callback,
+ _wsl_genmask_callback,
nullptr, /* on_frame_recv_start_callback */
nullptr, /* on_frame_recv_callback */
nullptr, /* on_frame_recv_end_callback */
- wsl_msg_recv_callback
+ _wsl_msg_recv_callback
};
-Error WSLPeer::parse_message(const wslay_event_on_msg_recv_arg *arg) {
- uint8_t is_string = 0;
- if (arg->opcode == WSLAY_TEXT_FRAME) {
- is_string = 1;
- } else if (arg->opcode == WSLAY_CONNECTION_CLOSE) {
- close_code = arg->status_code;
- size_t len = arg->msg_length;
- close_reason = "";
- if (len > 2 /* first 2 bytes = close code */) {
- close_reason.parse_utf8((char *)arg->msg + 2, len - 2);
- }
- if (!wslay_event_get_close_sent(_data->ctx)) {
- if (_data->is_server) {
- WSLServer *helper = static_cast<WSLServer *>(_data->obj);
- helper->_on_close_request(_data->id, close_code, close_reason);
- } else {
- WSLClient *helper = static_cast<WSLClient *>(_data->obj);
- helper->_on_close_request(close_code, close_reason);
- }
- }
- return ERR_FILE_EOF;
- } else if (arg->opcode != WSLAY_BINARY_FRAME) {
- // Ping or pong
- return ERR_SKIP;
- }
- _in_buffer.write_packet(arg->msg, arg->msg_length, &is_string);
- return OK;
-}
-
-void WSLPeer::make_context(PeerData *p_data, unsigned int p_in_buf_size, unsigned int p_in_pkt_size, unsigned int p_out_buf_size, unsigned int p_out_pkt_size) {
- ERR_FAIL_COND(_data != nullptr);
- ERR_FAIL_COND(p_data == nullptr);
-
- _in_buffer.resize(p_in_pkt_size, p_in_buf_size);
- _packet_buffer.resize(1 << p_in_buf_size);
- _out_buf_size = p_out_buf_size;
- _out_pkt_size = p_out_pkt_size;
-
- _data = p_data;
- _data->peer = this;
- _data->valid = true;
-
- if (_data->is_server) {
- wslay_event_context_server_init(&(_data->ctx), &wsl_callbacks, _data);
- } else {
- wslay_event_context_client_init(&(_data->ctx), &wsl_callbacks, _data);
- }
- wslay_event_config_set_max_recv_msg_length(_data->ctx, (1ULL << p_in_buf_size));
-}
-
-void WSLPeer::set_write_mode(WriteMode p_mode) {
- write_mode = p_mode;
+String WSLPeer::_generate_key() {
+ // Random key
+ Vector<uint8_t> bkey;
+ int len = 16; // 16 bytes, as per RFC
+ bkey.resize(len);
+ _wsl_genmask_callback(nullptr, bkey.ptrw(), len, nullptr);
+ return CryptoCore::b64_encode_str(bkey.ptrw(), len);
}
-WSLPeer::WriteMode WSLPeer::get_write_mode() const {
- return write_mode;
+String WSLPeer::_compute_key_response(String p_key) {
+ String key = p_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // Magic UUID as per RFC
+ Vector<uint8_t> sha = key.sha1_buffer();
+ return CryptoCore::b64_encode_str(sha.ptr(), sha.size());
}
void WSLPeer::poll() {
- if (!_data) {
+ // Nothing to do.
+ if (ready_state == STATE_CLOSED) {
return;
}
- if (_wsl_poll(_data)) {
- _data = nullptr;
+ if (ready_state == STATE_CONNECTING) {
+ if (is_server) {
+ _do_server_handshake();
+ } else {
+ _do_client_handshake();
+ }
+ }
+
+ if (ready_state == STATE_OPEN || ready_state == STATE_CLOSING) {
+ ERR_FAIL_COND(!wsl_ctx);
+ int err = 0;
+ if ((err = wslay_event_recv(wsl_ctx)) != 0 || (err = wslay_event_send(wsl_ctx)) != 0) {
+ // Error close.
+ print_verbose("Websocket (wslay) poll error: " + itos(err));
+ wslay_event_context_free(wsl_ctx);
+ wsl_ctx = nullptr;
+ close(-1);
+ return;
+ }
+ if (wslay_event_get_close_sent(wsl_ctx) && wslay_event_get_close_received(wsl_ctx)) {
+ // Clean close.
+ wslay_event_context_free(wsl_ctx);
+ wsl_ctx = nullptr;
+ close(-1);
+ return;
+ }
}
}
-Error WSLPeer::put_packet(const uint8_t *p_buffer, int p_buffer_size) {
- ERR_FAIL_COND_V(!is_connected_to_host(), FAILED);
- ERR_FAIL_COND_V(_out_pkt_size && (wslay_event_get_queued_msg_count(_data->ctx) >= (1ULL << _out_pkt_size)), ERR_OUT_OF_MEMORY);
- ERR_FAIL_COND_V(_out_buf_size && (wslay_event_get_queued_msg_length(_data->ctx) + p_buffer_size >= (1ULL << _out_buf_size)), ERR_OUT_OF_MEMORY);
+Error WSLPeer::_send(const uint8_t *p_buffer, int p_buffer_size, wslay_opcode p_opcode) {
+ ERR_FAIL_COND_V(ready_state != STATE_OPEN, FAILED);
+ ERR_FAIL_COND_V(wslay_event_get_queued_msg_count(wsl_ctx) >= (uint32_t)max_queued_packets, ERR_OUT_OF_MEMORY);
+ ERR_FAIL_COND_V(outbound_buffer_size > 0 && (wslay_event_get_queued_msg_length(wsl_ctx) + p_buffer_size > (uint32_t)outbound_buffer_size), ERR_OUT_OF_MEMORY);
struct wslay_event_msg msg;
- msg.opcode = write_mode == WRITE_MODE_TEXT ? WSLAY_TEXT_FRAME : WSLAY_BINARY_FRAME;
+ msg.opcode = p_opcode;
msg.msg = p_buffer;
msg.msg_length = p_buffer_size;
// Queue & send message.
- if (wslay_event_queue_msg(_data->ctx, &msg) != 0 || wslay_event_send(_data->ctx) != 0) {
- close_now();
+ if (wslay_event_queue_msg(wsl_ctx, &msg) != 0 || wslay_event_send(wsl_ctx) != 0) {
+ close(-1);
return FAILED;
}
return OK;
}
+Error WSLPeer::send(const uint8_t *p_buffer, int p_buffer_size, WriteMode p_mode) {
+ wslay_opcode opcode = p_mode == WRITE_MODE_TEXT ? WSLAY_TEXT_FRAME : WSLAY_BINARY_FRAME;
+ return _send(p_buffer, p_buffer_size, opcode);
+}
+
+Error WSLPeer::put_packet(const uint8_t *p_buffer, int p_buffer_size) {
+ return _send(p_buffer, p_buffer_size, WSLAY_BINARY_FRAME);
+}
+
Error WSLPeer::get_packet(const uint8_t **r_buffer, int &r_buffer_size) {
r_buffer_size = 0;
- ERR_FAIL_COND_V(!is_connected_to_host(), FAILED);
+ ERR_FAIL_COND_V(ready_state != STATE_OPEN, FAILED);
- if (_in_buffer.packets_left() == 0) {
+ if (in_buffer.packets_left() == 0) {
return ERR_UNAVAILABLE;
}
int read = 0;
- uint8_t *rw = _packet_buffer.ptrw();
- _in_buffer.read_packet(rw, _packet_buffer.size(), &_is_string, read);
+ uint8_t *rw = packet_buffer.ptrw();
+ in_buffer.read_packet(rw, packet_buffer.size(), &was_string, read);
*r_buffer = rw;
r_buffer_size = read;
@@ -277,75 +741,106 @@ Error WSLPeer::get_packet(const uint8_t **r_buffer, int &r_buffer_size) {
}
int WSLPeer::get_available_packet_count() const {
- if (!is_connected_to_host()) {
+ if (ready_state != STATE_OPEN) {
return 0;
}
- return _in_buffer.packets_left();
+ return in_buffer.packets_left();
}
int WSLPeer::get_current_outbound_buffered_amount() const {
- ERR_FAIL_COND_V(!_data, 0);
-
- return wslay_event_get_queued_msg_length(_data->ctx);
-}
-
-bool WSLPeer::was_string_packet() const {
- return _is_string;
-}
-
-bool WSLPeer::is_connected_to_host() const {
- return _data != nullptr;
-}
+ if (ready_state != STATE_OPEN) {
+ return 0;
+ }
-void WSLPeer::close_now() {
- close(1000, "");
- _wsl_destroy(&_data);
+ return wslay_event_get_queued_msg_length(wsl_ctx);
}
void WSLPeer::close(int p_code, String p_reason) {
- if (_data && !wslay_event_get_close_sent(_data->ctx)) {
+ if (p_code < 0) {
+ // Force immediate close.
+ ready_state = STATE_CLOSED;
+ }
+
+ if (ready_state == STATE_OPEN && !wslay_event_get_close_sent(wsl_ctx)) {
CharString cs = p_reason.utf8();
- wslay_event_queue_close(_data->ctx, p_code, (uint8_t *)cs.ptr(), cs.size());
- wslay_event_send(_data->ctx);
- _data->closing = true;
+ wslay_event_queue_close(wsl_ctx, p_code, (uint8_t *)cs.ptr(), cs.length());
+ wslay_event_send(wsl_ctx);
+ ready_state = STATE_CLOSING;
+ } else if (ready_state == STATE_CONNECTING || ready_state == STATE_CLOSED) {
+ ready_state = STATE_CLOSED;
+ connection.unref();
+ if (tcp.is_valid()) {
+ tcp->disconnect_from_host();
+ tcp.unref();
+ }
}
- _in_buffer.clear();
- _packet_buffer.resize(0);
+ in_buffer.clear();
+ packet_buffer.resize(0);
}
IPAddress WSLPeer::get_connected_host() const {
- ERR_FAIL_COND_V(!is_connected_to_host() || _data->tcp.is_null(), IPAddress());
-
- return _data->tcp->get_connected_host();
+ ERR_FAIL_COND_V(tcp.is_null(), IPAddress());
+ return tcp->get_connected_host();
}
uint16_t WSLPeer::get_connected_port() const {
- ERR_FAIL_COND_V(!is_connected_to_host() || _data->tcp.is_null(), 0);
+ ERR_FAIL_COND_V(tcp.is_null(), 0);
+ return tcp->get_connected_port();
+}
+
+String WSLPeer::get_selected_protocol() const {
+ return selected_protocol;
+}
- return _data->tcp->get_connected_port();
+String WSLPeer::get_requested_url() const {
+ return requested_url;
}
void WSLPeer::set_no_delay(bool p_enabled) {
- ERR_FAIL_COND(!is_connected_to_host() || _data->tcp.is_null());
- _data->tcp->set_no_delay(p_enabled);
+ ERR_FAIL_COND(tcp.is_null());
+ tcp->set_no_delay(p_enabled);
}
-void WSLPeer::invalidate() {
- if (_data) {
- _data->valid = false;
+void WSLPeer::_clear() {
+ // Connection info.
+ ready_state = STATE_CLOSED;
+ is_server = false;
+ connection.unref();
+ if (tcp.is_valid()) {
+ tcp->disconnect_from_host();
+ tcp.unref();
}
+ if (wsl_ctx) {
+ wslay_event_context_free(wsl_ctx);
+ wsl_ctx = nullptr;
+ }
+
+ resolver.stop();
+ requested_url.clear();
+ requested_host.clear();
+ pending_request = true;
+ handshake_buffer->clear();
+ selected_protocol.clear();
+ session_key.clear();
+
+ // Pending packets info.
+ was_string = 0;
+ in_buffer.clear();
+ packet_buffer.clear();
+
+ // Close code info.
+ close_code = -1;
+ close_reason.clear();
}
WSLPeer::WSLPeer() {
+ handshake_buffer.instantiate();
}
WSLPeer::~WSLPeer() {
- close();
- invalidate();
- _wsl_destroy(&_data);
- _data = nullptr;
+ close(-1);
}
#endif // WEB_ENABLED