This source file includes the following definitions:
- pool_
- CreateClientSocketHandle
- OnSuccess
- OnFailure
- OnStartOpeningHandshake
- OnFinishOpeningHandshake
- CreateAndInitializeStream
- TEST_F
- TEST_F
- TEST_F
- TEST_F
#include "net/websockets/websocket_handshake_stream_create_helper.h"
#include <string>
#include <vector>
#include "net/base/completion_callback.h"
#include "net/base/net_errors.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_request_info.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_response_info.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/socket_test_util.h"
#include "net/websockets/websocket_basic_handshake_stream.h"
#include "net/websockets/websocket_stream.h"
#include "net/websockets/websocket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
namespace net {
namespace {
// Hands out ClientSocketHandles whose sockets come from the mock socket
// factory owned by |socket_factory_maker_|, via a MockTransportClientSocketPool,
// so a handshake stream can be exercised without a real network connection.
class MockClientSocketHandleFactory {
 public:
  MockClientSocketHandleFactory()
      : histograms_("a"),
        pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {}

  // Configures the mock socket factory with |expect_written| (presumably the
  // bytes the socket must see written) and |return_to_read| (presumably the
  // bytes it hands back on read), then returns a handle initialized from
  // |pool_|.
  //
  // Init() is given a null CompletionCallback, so an asynchronous or failed
  // connect could never be observed by the caller; anything other than a
  // synchronous OK is reported as a test failure here rather than showing up
  // later as a crash when the unconnected handle is used.
  scoped_ptr<ClientSocketHandle> CreateClientSocketHandle(
      const std::string& expect_written,
      const std::string& return_to_read) {
    socket_factory_maker_.SetExpectations(expect_written, return_to_read);
    scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle);
    int rv = socket_handle->Init(
        "a",
        scoped_refptr<MockTransportSocketParams>(),
        MEDIUM,
        CompletionCallback(),
        &pool_,
        BoundNetLog());
    EXPECT_EQ(OK, rv);
    return socket_handle.Pass();
  }

 private:
  // |socket_factory_maker_| must be declared before |pool_|: members are
  // initialized in declaration order, and the constructor above reads
  // socket_factory_maker_.factory() while initializing |pool_|.
  WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_;
  ClientSocketPoolHistograms histograms_;
  MockTransportClientSocketPool pool_;

  DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory);
};
class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
public:
virtual ~TestConnectDelegate() {}
virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE {}
virtual void OnFailure(const std::string& failure_message) OVERRIDE {}
virtual void OnStartOpeningHandshake(
scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {}
virtual void OnFinishOpeningHandshake(
scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {}
};
// Fixture that runs a complete WebSocket opening handshake over a mock socket
// and hands back the resulting WebSocketStream for inspection.
class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {
 protected:
  // Performs the opening handshake for |socket_url| and returns the upgraded
  // stream.
  //
  // |socket_path|, |origin| and |extra_request_headers| are passed to
  // WebSocketStandardRequest() to build the request the mock socket expects;
  // |extra_response_headers| is passed to WebSocketStandardResponse() to build
  // what the mock socket returns. |sub_protocols| is handed to the create
  // helper as the requested sub-protocols.
  //
  // Each step's result is checked with EXPECT_*. Because this function returns
  // a value, ASSERT_* cannot be used; instead, if no response headers were
  // parsed, a failure is added and an empty scoped_ptr is returned rather than
  // dereferencing the missing headers. Callers should check the result.
  scoped_ptr<WebSocketStream> CreateAndInitializeStream(
      const std::string& socket_url,
      const std::string& socket_path,
      const std::vector<std::string>& sub_protocols,
      const std::string& origin,
      const std::string& extra_request_headers,
      const std::string& extra_response_headers) {
    WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_,
                                                       sub_protocols);
    scoped_ptr<ClientSocketHandle> socket_handle =
        socket_handle_factory_.CreateClientSocketHandle(
            WebSocketStandardRequest(
                socket_path, origin, extra_request_headers),
            WebSocketStandardResponse(extra_response_headers));

    scoped_ptr<WebSocketHandshakeStreamBase> handshake(
        create_helper.CreateBasicStream(socket_handle.Pass(), false));

    // Pin Sec-WebSocket-Key to the sample nonce from RFC 6455 section 1.3 so
    // the written request is deterministic. NOTE(review): presumably this is
    // the key WebSocketStandardRequest() bakes into the expected request;
    // keep the two in sync.
    static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
        ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");

    HttpRequestInfo request_info;
    request_info.url = GURL(socket_url);
    request_info.method = "GET";
    request_info.load_flags = LOAD_DISABLE_CACHE | LOAD_DO_NOT_PROMPT_FOR_LOGIN;
    int rv = handshake->InitializeStream(
        &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback());
    EXPECT_EQ(OK, rv);

    // NOTE(review): the mock socket checks the bytes written against
    // WebSocketStandardRequest(), so the header set and order here are
    // presumably significant; keep them in sync with that helper.
    HttpRequestHeaders headers;
    headers.SetHeader("Host", "localhost");
    headers.SetHeader("Connection", "Upgrade");
    headers.SetHeader("Pragma", "no-cache");
    headers.SetHeader("Cache-Control", "no-cache");
    headers.SetHeader("Upgrade", "websocket");
    headers.SetHeader("Origin", origin);
    headers.SetHeader("Sec-WebSocket-Version", "13");
    headers.SetHeader("User-Agent", "");
    headers.SetHeader("Accept-Encoding", "gzip,deflate");
    headers.SetHeader("Accept-Language", "en-us,fr");

    HttpResponseInfo response;
    TestCompletionCallback dummy;
    rv = handshake->SendRequest(headers, &response, dummy.callback());
    EXPECT_EQ(OK, rv);
    rv = handshake->ReadResponseHeaders(dummy.callback());
    EXPECT_EQ(OK, rv);

    // Without parsed headers the checks below would dereference null; report
    // a failure and bail out instead of crashing the test binary.
    if (!response.headers.get()) {
      ADD_FAILURE() << "ReadResponseHeaders() produced no response headers";
      return scoped_ptr<WebSocketStream>();
    }

    EXPECT_EQ(101, response.headers->response_code());
    EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
    EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
    return handshake->Upgrade();
  }

  MockClientSocketHandleFactory socket_handle_factory_;
  TestConnectDelegate connect_delegate_;
};
// With no requested sub-protocols and no extension headers in the response,
// the resulting stream reports neither an extension nor a sub-protocol.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
  scoped_ptr<WebSocketStream> stream =
      CreateAndInitializeStream("ws://localhost/", "/",
                                std::vector<std::string>(), "http://localhost/",
                                "", "");
  // Fail cleanly (instead of crashing below) if the handshake yielded no
  // stream.
  ASSERT_TRUE(stream.get());
  EXPECT_EQ("", stream->GetExtensions());
  EXPECT_EQ("", stream->GetSubProtocol());
}
// The sub-protocol selected by the server ("superchat") from the requested
// list is exposed by the resulting stream.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
  std::vector<std::string> sub_protocols;
  sub_protocols.push_back("chat");
  sub_protocols.push_back("superchat");
  scoped_ptr<WebSocketStream> stream =
      CreateAndInitializeStream("ws://localhost/",
                                "/",
                                sub_protocols,
                                "http://localhost/",
                                "Sec-WebSocket-Protocol: chat, superchat\r\n",
                                "Sec-WebSocket-Protocol: superchat\r\n");
  // Fail cleanly (instead of crashing below) if the handshake yielded no
  // stream.
  ASSERT_TRUE(stream.get());
  EXPECT_EQ("superchat", stream->GetSubProtocol());
  // The response carries no Sec-WebSocket-Extensions header.
  EXPECT_EQ("", stream->GetExtensions());
}
// An extension accepted by the server (permessage-deflate, no parameters) is
// reported verbatim by the resulting stream.
TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
  scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
      "ws://localhost/",
      "/",
      std::vector<std::string>(),
      "http://localhost/",
      "",
      "Sec-WebSocket-Extensions: permessage-deflate\r\n");
  // Fail cleanly (instead of crashing below) if the handshake yielded no
  // stream.
  ASSERT_TRUE(stream.get());
  EXPECT_EQ("permessage-deflate", stream->GetExtensions());
  // No sub-protocols were requested or returned.
  EXPECT_EQ("", stream->GetSubProtocol());
}
// permessage-deflate with all four parameters set is reported by the resulting
// stream exactly as the server sent it (minus the trailing CRLF).
TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
  scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
      "ws://localhost/",
      "/",
      std::vector<std::string>(),
      "http://localhost/",
      "",
      "Sec-WebSocket-Extensions: permessage-deflate;"
      " client_max_window_bits=14; server_max_window_bits=14;"
      " server_no_context_takeover; client_no_context_takeover\r\n");
  // Fail cleanly (instead of crashing below) if the handshake yielded no
  // stream.
  ASSERT_TRUE(stream.get());
  EXPECT_EQ(
      "permessage-deflate;"
      " client_max_window_bits=14; server_max_window_bits=14;"
      " server_no_context_takeover; client_no_context_takeover",
      stream->GetExtensions());
  // No sub-protocols were requested or returned.
  EXPECT_EQ("", stream->GetSubProtocol());
}
}
}