diff options
Diffstat (limited to 'modules/websocket/wsl_client.cpp')
-rw-r--r-- | modules/websocket/wsl_client.cpp | 90 |
1 files changed, 39 insertions, 51 deletions
diff --git a/modules/websocket/wsl_client.cpp b/modules/websocket/wsl_client.cpp index 86374e8f80..a422f65cfc 100644 --- a/modules/websocket/wsl_client.cpp +++ b/modules/websocket/wsl_client.cpp @@ -53,8 +53,7 @@ void WSLClient::_do_handshake() { // Header is too big disconnect_from_host(); _on_error(); - ERR_EXPLAIN("Response headers too big"); - ERR_FAIL(); + ERR_FAIL_MSG("Response headers too big."); } Error err = _connection->get_partial_data(&_resp_buf[_resp_pos], 1, read); if (err == ERR_FILE_EOF) { @@ -81,13 +80,13 @@ void WSLClient::_do_handshake() { if (!_verify_headers(protocol)) { disconnect_from_host(); _on_error(); - ERR_EXPLAIN("Invalid response headers"); - ERR_FAIL(); + ERR_FAIL_MSG("Invalid response headers."); } // Create peer. WSLPeer::PeerData *data = memnew(struct WSLPeer::PeerData); data->obj = this; data->conn = _connection; + data->tcp = _tcp; data->is_server = false; data->id = 1; _peer->make_context(data, _in_buf_size, _in_pkt_size, _out_buf_size, _out_pkt_size); @@ -103,29 +102,18 @@ bool WSLClient::_verify_headers(String &r_protocol) { String s = (char *)_resp_buf; Vector<String> psa = s.split("\r\n"); int len = psa.size(); - if (len < 4) { - ERR_EXPLAIN("Not enough response headers."); - ERR_FAIL_V(false); - } + ERR_FAIL_COND_V_MSG(len < 4, false, "Not enough response headers, got: " + itos(len) + ", expected >= 4."); Vector<String> req = psa[0].split(" ", false); - if (req.size() < 2) { - ERR_EXPLAIN("Invalid protocol or status code."); - ERR_FAIL_V(false); - } + ERR_FAIL_COND_V_MSG(req.size() < 2, false, "Invalid protocol or status code."); + // Wrong protocol - if (req[0] != "HTTP/1.1" || req[1] != "101") { - ERR_EXPLAIN("Invalid protocol or status code."); - ERR_FAIL_V(false); - } + ERR_FAIL_COND_V_MSG(req[0] != "HTTP/1.1" || req[1] != "101", false, "Invalid protocol or status code."); Map<String, String> headers; for (int i = 1; i < len; i++) { Vector<String> header = psa[i].split(":", false, 1); - if (header.size() != 2) { - ERR_EXPLAIN("Invalid header -> " + psa[i]); - ERR_FAIL_V(false); - } + ERR_FAIL_COND_V_MSG(header.size() != 2, false, "Invalid header -> " + psa[i] + "."); String name = header[0].to_lower(); String value = header[1].strip_edges(); if (headers.has(name)) @@ -134,17 +122,17 @@ bool WSLClient::_verify_headers(String &r_protocol) { headers[name] = value; } -#define _WLS_EXPLAIN(NAME, VALUE) \ - ERR_EXPLAIN("Missing or invalid header '" + String(NAME) + "'. Expected value '" + VALUE + "'"); -#define _WLS_CHECK(NAME, VALUE) \ - _WLS_EXPLAIN(NAME, VALUE); \ - ERR_FAIL_COND_V(!headers.has(NAME) || headers[NAME].to_lower() != VALUE, false); -#define _WLS_CHECK_NC(NAME, VALUE) \ - _WLS_EXPLAIN(NAME, VALUE); \ - ERR_FAIL_COND_V(!headers.has(NAME) || headers[NAME] != VALUE, false); - _WLS_CHECK("connection", "upgrade"); - _WLS_CHECK("upgrade", "websocket"); - _WLS_CHECK_NC("sec-websocket-accept", WSLPeer::compute_key_response(_key)); +#define _WSL_CHECK(NAME, VALUE) \ + ERR_FAIL_COND_V_MSG(!headers.has(NAME) || headers[NAME].to_lower() != VALUE, false, \ + "Missing or invalid header '" + String(NAME) + "'. Expected value '" + VALUE + "'."); +#define _WSL_CHECK_NC(NAME, VALUE) \ + ERR_FAIL_COND_V_MSG(!headers.has(NAME) || headers[NAME] != VALUE, false, \ + "Missing or invalid header '" + String(NAME) + "'. Expected value '" + VALUE + "'."); + _WSL_CHECK("connection", "upgrade"); + _WSL_CHECK("upgrade", "websocket"); + _WSL_CHECK_NC("sec-websocket-accept", WSLPeer::compute_key_response(_key)); +#undef _WSL_CHECK_NC +#undef _WSL_CHECK if (_protocols.size() == 0) { // We didn't request a custom protocol ERR_FAIL_COND_V(headers.has("sec-websocket-protocol"), false); @@ -161,14 +149,10 @@ bool WSLClient::_verify_headers(String &r_protocol) { if (!valid) return false; } -#undef _WLS_CHECK_NC -#undef _WLS_CHECK -#undef _WLS_EXPLAIN - return true; } -Error WSLClient::connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, PoolVector<String> p_protocols) { +Error WSLClient::connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, const Vector<String> p_protocols, const Vector<String> p_custom_headers) { ERR_FAIL_COND_V(_connection.is_valid(), ERR_ALREADY_IN_USE); @@ -190,14 +174,15 @@ Error WSLClient::connect_to_host(String p_host, String p_path, uint16_t p_port, Error err = _tcp->connect_to_host(addr, p_port); if (err != OK) { - _on_error(); _tcp->disconnect_from_host(); + _on_error(); return err; } _connection = _tcp; _use_ssl = p_ssl; _host = p_host; - _protocols = p_protocols; + _protocols.clear(); + _protocols.append_array(p_protocols); _key = WSLPeer::generate_key(); // TODO custom extra headers (allow overriding this too?) @@ -216,6 +201,9 @@ Error WSLClient::connect_to_host(String p_host, String p_path, uint16_t p_port, } request += "\r\n"; } + for (int i = 0; i < p_custom_headers.size(); i++) { + request += p_custom_headers[i] + "\r\n"; + } request += "\r\n"; _request = request.utf8(); @@ -230,8 +218,8 @@ void WSLClient::poll() { if (_peer->is_connected_to_host()) { _peer->poll(); if (!_peer->is_connected_to_host()) { - _on_disconnect(_peer->close_code != -1); disconnect_from_host(); + _on_disconnect(_peer->close_code != -1); } return; } @@ -242,8 +230,8 @@ void WSLClient::poll() { switch (_tcp->get_status()) { case StreamPeerTCP::STATUS_NONE: // Clean close - _on_error(); disconnect_from_host(); + _on_error(); break; case StreamPeerTCP::STATUS_CONNECTED: { Ref<StreamPeerSSL> ssl; @@ -251,12 +239,11 @@ void WSLClient::poll() { if (_connection == _tcp) { // Start SSL handshake ssl = Ref<StreamPeerSSL>(StreamPeerSSL::create()); - ERR_EXPLAIN("SSL is not available in this build"); - ERR_FAIL_COND(ssl.is_null()); + ERR_FAIL_COND_MSG(ssl.is_null(), "SSL is not available in this build."); ssl->set_blocking_handshake_enabled(false); - if (ssl->connect_to_stream(_tcp, verify_ssl, _host) != OK) { - _on_error(); + if (ssl->connect_to_stream(_tcp, verify_ssl, _host, ssl_cert) != OK) { disconnect_from_host(); + _on_error(); return; } _connection = ssl; @@ -268,8 +255,8 @@ void WSLClient::poll() { if (ssl->get_status() == StreamPeerSSL::STATUS_HANDSHAKING) return; // Need more polling. else if (ssl->get_status() != StreamPeerSSL::STATUS_CONNECTED) { - _on_error(); disconnect_from_host(); + _on_error(); return; // Error. } } @@ -277,8 +264,8 @@ void WSLClient::poll() { _do_handshake(); } break; case StreamPeerTCP::STATUS_ERROR: - _on_error(); disconnect_from_host(); + _on_error(); break; case StreamPeerTCP::STATUS_CONNECTING: break; // Wait for connection @@ -311,7 +298,7 @@ void WSLClient::disconnect_from_host(int p_code, String p_reason) { _key = ""; _host = ""; - _protocols.resize(0); + _protocols.clear(); _use_ssl = false; _request = ""; @@ -323,17 +310,18 @@ void WSLClient::disconnect_from_host(int p_code, String p_reason) { IP_Address WSLClient::get_connected_host() const { - return IP_Address(); + ERR_FAIL_COND_V(!_peer->is_connected_to_host(), IP_Address()); + return _peer->get_connected_host(); } uint16_t WSLClient::get_connected_port() const { - return 1025; + ERR_FAIL_COND_V(!_peer->is_connected_to_host(), 0); + return _peer->get_connected_port(); } Error WSLClient::set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) { - ERR_EXPLAIN("Buffers sizes can only be set before listening or connecting"); - ERR_FAIL_COND_V(_connection.is_valid(), FAILED); + ERR_FAIL_COND_V_MSG(_connection.is_valid(), FAILED, "Buffers sizes can only be set before listening or connecting."); _in_buf_size = nearest_shift(p_in_buffer - 1) + 10; _in_pkt_size = nearest_shift(p_in_packets - 1); |