This source file includes the following definitions.
- Connect
- Disconnect
- IsConnected
- IsConnectedAndIdle
- GetPeerAddress
- GetLocalAddress
- SetSubresourceSpeculation
- SetOmniboxSpeculation
- WasEverUsed
- UsingTCPFastOpen
- WasNpnNegotiated
- GetNegotiatedProtocol
- GetSSLInfo
- Read
- Write
- SetReceiveBufferSize
- SetSendBufferSize
- buffer_size_
- SetBufferSize
- DoLoop
- DoRead
- DoReadComplete
- OnReadCompleted
- SetNextReadError
- SetNextWriteError
- pending_write_error_
- Read
- Write
- Read
- Write
- SetNextReadShouldBlock
- UnblockRead
- SetNextWriteShouldBlock
- UnblockWrite
- write_state_
- pending_result_
- SetShouldBlock
- Unblock
- RunWrappedFunction
- OnCompleted
- callback_
- callback
- OnComplete
- transport_security_state_
- CreateSSLClientSocket
- GetCertRequest
- LogContainsSSLConnectEndEvent
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
- TEST_F
#include "net/socket/ssl_client_socket.h"
#include "base/callback_helpers.h"
#include "base/memory/ref_counted.h"
#include "net/base/address_list.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/net_log.h"
#include "net/base/net_log_unittest.h"
#include "net/base/test_completion_callback.h"
#include "net/base/test_data_directory.h"
#include "net/cert/mock_cert_verifier.h"
#include "net/cert/test_root_certs.h"
#include "net/dns/host_resolver.h"
#include "net/http/transport_security_state.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/tcp_client_socket.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_config_service.h"
#include "net/test/cert_test_util.h"
#include "net/test/spawned_test_server/spawned_test_server.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
namespace net {
namespace {
// Default-constructed SSL configuration, passed to CreateSSLClientSocket() by
// tests that do not customize any SSL settings.
const SSLConfig kDefaultSSLConfig;
// StreamSocket implementation that owns another StreamSocket and forwards
// every call to it unchanged. The other helper sockets in this file derive
// from it and override only the calls whose behavior they want to alter.
class WrappedStreamSocket : public StreamSocket {
public:
// Takes ownership of |transport|, the socket that all calls are forwarded to.
explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport)
: transport_(transport.Pass()) {}
virtual ~WrappedStreamSocket() {}
// Connection-management and query methods: each delegates directly to
// |transport_| and returns whatever it returns.
virtual int Connect(const CompletionCallback& callback) OVERRIDE {
return transport_->Connect(callback);
}
virtual void Disconnect() OVERRIDE { transport_->Disconnect(); }
virtual bool IsConnected() const OVERRIDE {
return transport_->IsConnected();
}
virtual bool IsConnectedAndIdle() const OVERRIDE {
return transport_->IsConnectedAndIdle();
}
virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
return transport_->GetPeerAddress(address);
}
virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
return transport_->GetLocalAddress(address);
}
virtual const BoundNetLog& NetLog() const OVERRIDE {
return transport_->NetLog();
}
virtual void SetSubresourceSpeculation() OVERRIDE {
transport_->SetSubresourceSpeculation();
}
virtual void SetOmniboxSpeculation() OVERRIDE {
transport_->SetOmniboxSpeculation();
}
virtual bool WasEverUsed() const OVERRIDE {
return transport_->WasEverUsed();
}
virtual bool UsingTCPFastOpen() const OVERRIDE {
return transport_->UsingTCPFastOpen();
}
virtual bool WasNpnNegotiated() const OVERRIDE {
return transport_->WasNpnNegotiated();
}
virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
return transport_->GetNegotiatedProtocol();
}
virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
return transport_->GetSSLInfo(ssl_info);
}
// I/O and buffer-size calls are likewise forwarded untouched.
virtual int Read(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE {
return transport_->Read(buf, buf_len, callback);
}
virtual int Write(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE {
return transport_->Write(buf, buf_len, callback);
}
virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
return transport_->SetReceiveBufferSize(size);
}
virtual int SetSendBufferSize(int32 size) OVERRIDE {
return transport_->SetSendBufferSize(size);
}
protected:
// The wrapped socket. Protected so that subclasses can call through to it.
scoped_ptr<StreamSocket> transport_;
};
// WrappedStreamSocket whose Read() can be configured, via SetBufferSize(), to
// accumulate data from the transport into an internal buffer and only
// complete once that buffer has been completely filled. While the configured
// size is 0, Read() is passed straight through to the transport.
class ReadBufferingStreamSocket : public WrappedStreamSocket {
public:
explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport);
virtual ~ReadBufferingStreamSocket() {}
// Overrides the forwarding Read() with the buffering behavior described
// above.
virtual int Read(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE;
// Sets the number of bytes that must be buffered before a Read() completes.
// Must not be called while a buffered read is outstanding.
void SetBufferSize(int size);
private:
// States of the internal read loop driven by DoLoop().
enum State {
STATE_NONE,
STATE_READ,
STATE_READ_COMPLETE,
};
// Runs the state machine until it would block or there is no next state.
int DoLoop(int result);
// Issues a read on the transport into |read_buffer_|.
int DoRead();
// Handles the result of a transport read.
int DoReadComplete(int result);
// Completion callback handed to the transport by DoRead().
void OnReadCompleted(int result);
// Next state for DoLoop() to execute.
State state_;
// Internal buffer that transport data is accumulated into.
scoped_refptr<GrowableIOBuffer> read_buffer_;
// Configured buffer size; 0 disables buffering.
int buffer_size_;
// Caller's buffer and callback for the buffered read currently in progress.
scoped_refptr<IOBuffer> user_read_buf_;
CompletionCallback user_read_callback_;
};
// Wraps |transport| with buffering initially disabled (|buffer_size_| == 0),
// so Read() is passed straight through until SetBufferSize() is called.
ReadBufferingStreamSocket::ReadBufferingStreamSocket(
    scoped_ptr<StreamSocket> transport)
    : WrappedStreamSocket(transport.Pass()),
      // Give the state machine a well-defined starting state instead of
      // leaving |state_| indeterminate until the first buffered Read().
      state_(STATE_NONE),
      read_buffer_(new GrowableIOBuffer()),
      buffer_size_(0) {}
// Sets the number of bytes a buffered Read() must accumulate before it
// completes, and resizes the internal buffer to match.
void ReadBufferingStreamSocket::SetBufferSize(int size) {
  // Changing the size is only valid while no buffered read is outstanding.
  DCHECK(!user_read_buf_.get());
  read_buffer_->SetCapacity(size);
  buffer_size_ = size;
}
// Reads from the transport. With buffering disabled the call is forwarded
// unchanged; otherwise the read only completes once |buffer_size_| bytes have
// been accumulated (or the transport reports an error / end of stream).
// Returns ERR_UNEXPECTED if |buf_len| is smaller than the configured buffer.
int ReadBufferingStreamSocket::Read(IOBuffer* buf,
                                    int buf_len,
                                    const CompletionCallback& callback) {
  const bool buffering_enabled = (buffer_size_ != 0);
  if (!buffering_enabled)
    return transport_->Read(buf, buf_len, callback);

  // The caller's buffer has to be able to hold an entire buffered chunk.
  if (buffer_size_ > buf_len)
    return ERR_UNEXPECTED;

  user_read_buf_ = buf;
  state_ = STATE_READ;
  const int rv = DoLoop(OK);
  if (rv != ERR_IO_PENDING) {
    // Finished synchronously: the caller's buffer is no longer needed.
    user_read_buf_ = NULL;
    return rv;
  }

  // Completion will be reported later through OnReadCompleted().
  user_read_callback_ = callback;
  return ERR_IO_PENDING;
}
int ReadBufferingStreamSocket::DoLoop(int result) {
int rv = result;
do {
State current_state = state_;
state_ = STATE_NONE;
switch (current_state) {
case STATE_READ:
rv = DoRead();
break;
case STATE_READ_COMPLETE:
rv = DoReadComplete(rv);
break;
case STATE_NONE:
default:
NOTREACHED() << "Unexpected state: " << current_state;
rv = ERR_UNEXPECTED;
break;
}
} while (rv != ERR_IO_PENDING && state_ != STATE_NONE);
return rv;
}
int ReadBufferingStreamSocket::DoRead() {
state_ = STATE_READ_COMPLETE;
int rv =
transport_->Read(read_buffer_.get(),
read_buffer_->RemainingCapacity(),
base::Bind(&ReadBufferingStreamSocket::OnReadCompleted,
base::Unretained(this)));
return rv;
}
// Processes the result of a transport read. Non-positive results are returned
// to the caller as-is. Otherwise the data is accounted for in |read_buffer_|;
// if the buffer still has room another read is scheduled, and once it is full
// its entire contents are copied to the caller's buffer and the byte count is
// returned.
int ReadBufferingStreamSocket::DoReadComplete(int result) {
  state_ = STATE_NONE;
  const bool got_data = result > 0;
  if (!got_data)
    return result;

  const int filled = read_buffer_->offset() + result;
  read_buffer_->set_offset(filled);

  const bool buffer_full = read_buffer_->RemainingCapacity() <= 0;
  if (!buffer_full) {
    // Keep reading until the whole buffer has been filled.
    state_ = STATE_READ;
    return OK;
  }

  const int total = read_buffer_->capacity();
  memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), total);
  // Rewind so the next buffered Read() starts filling from the beginning.
  read_buffer_->set_offset(0);
  return total;
}
// Transport completion callback: resumes the state machine with |result| and,
// unless it blocks again, releases the caller's buffer and reports the final
// result through the caller's callback.
void ReadBufferingStreamSocket::OnReadCompleted(int result) {
  const int rv = DoLoop(result);
  if (rv != ERR_IO_PENDING) {
    user_read_buf_ = NULL;
    base::ResetAndReturn(&user_read_callback_).Run(rv);
  }
}
// WrappedStreamSocket that can be told to fail Read() and/or Write() calls
// synchronously with a configured error instead of forwarding them to the
// transport.
class SynchronousErrorStreamSocket : public WrappedStreamSocket {
public:
explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport);
virtual ~SynchronousErrorStreamSocket() {}
// Return the configured error, if one has been set; otherwise forward to the
// transport.
virtual int Read(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE;
virtual int Write(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE;
// Makes subsequent Read() calls return |error| without touching the
// transport. |error| must be <= 0. The flag is never cleared by Read(), so
// the error keeps being returned on every later call.
void SetNextReadError(Error error) {
DCHECK_GE(0, error);
have_read_error_ = true;
pending_read_error_ = error;
}
// Same as SetNextReadError(), but for Write().
void SetNextWriteError(Error error) {
DCHECK_GE(0, error);
have_write_error_ = true;
pending_write_error_ = error;
}
private:
// Whether a read error has been configured, and its value.
bool have_read_error_;
int pending_read_error_;
// Whether a write error has been configured, and its value.
bool have_write_error_;
int pending_write_error_;
DISALLOW_COPY_AND_ASSIGN(SynchronousErrorStreamSocket);
};
// Wraps |transport| with no read or write error configured, so all I/O is
// forwarded until SetNextReadError()/SetNextWriteError() is called.
SynchronousErrorStreamSocket::SynchronousErrorStreamSocket(
scoped_ptr<StreamSocket> transport)
: WrappedStreamSocket(transport.Pass()),
have_read_error_(false),
pending_read_error_(OK),
have_write_error_(false),
pending_write_error_(OK) {}
// Returns the configured read error synchronously if one has been set;
// otherwise forwards the read to the transport.
int SynchronousErrorStreamSocket::Read(IOBuffer* buf,
                                       int buf_len,
                                       const CompletionCallback& callback) {
  return have_read_error_ ? pending_read_error_
                          : transport_->Read(buf, buf_len, callback);
}
// Returns the configured write error synchronously if one has been set;
// otherwise forwards the write to the transport.
int SynchronousErrorStreamSocket::Write(IOBuffer* buf,
                                        int buf_len,
                                        const CompletionCallback& callback) {
  return have_write_error_ ? pending_write_error_
                           : transport_->Write(buf, buf_len, callback);
}
// WrappedStreamSocket that can hold back the completion of a Read() or
// Write(): after SetNext*ShouldBlock(), the next operation still runs on the
// transport but reports ERR_IO_PENDING to the caller, and its real result is
// only delivered once the matching Unblock*() is called.
class FakeBlockingStreamSocket : public WrappedStreamSocket {
public:
explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport);
virtual ~FakeBlockingStreamSocket() {}
// Route reads and writes through their respective BlockingState.
virtual int Read(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE {
return read_state_.RunWrappedFunction(buf, buf_len, callback);
}
virtual int Write(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) OVERRIDE {
return write_state_.RunWrappedFunction(buf, buf_len, callback);
}
// Arm / release blocking for reads.
void SetNextReadShouldBlock() { read_state_.SetShouldBlock(); }
void UnblockRead() { read_state_.Unblock(); }
// Arm / release blocking for writes.
void SetNextWriteShouldBlock() { write_state_.SetShouldBlock(); }
void UnblockWrite() { write_state_.Unblock(); }
private:
// Tracks the blocking behavior of one direction (read or write).
class BlockingState {
public:
// Signature shared by the transport's Read and Write.
typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)>
WrappedSocketFunction;
// |function| is the transport operation that RunWrappedFunction() invokes.
explicit BlockingState(const WrappedSocketFunction& function);
~BlockingState() {}
// Starts blocking; must not already be blocking.
void SetShouldBlock();
// Stops blocking and, if a result was being held back, delivers it to the
// stored user callback.
void Unblock();
// Runs the wrapped function; while blocking, hides its result from the
// caller and returns ERR_IO_PENDING instead.
int RunWrappedFunction(IOBuffer* buf,
int len,
const CompletionCallback& user_callback);
private:
// Completion callback given to the transport while blocking.
void OnCompleted(int result);
WrappedSocketFunction wrapped_function_;
// True between SetShouldBlock() and Unblock().
bool should_block_;
// True when |pending_result_| holds a result waiting for Unblock().
bool have_result_;
int pending_result_;
// Caller's callback for the operation whose completion is being intercepted.
CompletionCallback user_callback_;
};
BlockingState read_state_;
BlockingState write_state_;
DISALLOW_COPY_AND_ASSIGN(FakeBlockingStreamSocket);
};
// Wraps |transport| and binds the read/write blocking states to the wrapped
// socket's Read and Write. The raw pointer passed to base::Unretained refers
// to |transport_|, which this object owns through its base class.
FakeBlockingStreamSocket::FakeBlockingStreamSocket(
scoped_ptr<StreamSocket> transport)
: WrappedStreamSocket(transport.Pass()),
read_state_(base::Bind(&Socket::Read,
base::Unretained(transport_.get()))),
write_state_(base::Bind(&Socket::Write,
base::Unretained(transport_.get()))) {}
// Stores |function| as the operation to wrap; starts out not blocking and
// with no held-back result.
FakeBlockingStreamSocket::BlockingState::BlockingState(
const WrappedSocketFunction& function)
: wrapped_function_(function),
should_block_(false),
have_result_(false),
pending_result_(OK) {}
// Arms blocking for subsequent operations. Must not be called while blocking
// is already armed.
void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() {
DCHECK(!should_block_);
should_block_ = true;
}
// Disarms blocking. If an operation's result was being held back, it is now
// delivered to the stored user callback; otherwise nothing else happens.
void FakeBlockingStreamSocket::BlockingState::Unblock() {
  DCHECK(should_block_);
  should_block_ = false;
  if (have_result_) {
    have_result_ = false;
    base::ResetAndReturn(&user_callback_).Run(pending_result_);
  }
}
// Invokes the wrapped transport operation. When not blocking, the caller's
// callback is handed to the transport and the transport's return value is
// returned unchanged. When blocking, completion is intercepted via
// OnCompleted(), the caller's callback is stored, any synchronous result is
// held back, and ERR_IO_PENDING is returned.
int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction(
    IOBuffer* buf,
    int len,
    const CompletionCallback& callback) {
  CompletionCallback transport_callback;
  if (should_block_) {
    transport_callback =
        base::Bind(&BlockingState::OnCompleted, base::Unretained(this));
  } else {
    transport_callback = callback;
  }

  const int rv = wrapped_function_.Run(buf, len, transport_callback);
  if (!should_block_)
    return rv;

  // Hide the outcome from the caller until Unblock() is called.
  user_callback_ = callback;
  pending_result_ = rv;
  have_result_ = (rv != ERR_IO_PENDING);
  return ERR_IO_PENDING;
}
// Transport completion for an intercepted operation. If blocking has already
// been lifted the result goes straight to the user callback; otherwise it is
// stored until Unblock() is called.
void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) {
  if (!should_block_) {
    base::ResetAndReturn(&user_callback_).Run(result);
    return;
  }
  pending_result_ = result;
  have_result_ = true;
}
// Completion callback that deletes the socket it was given when it is
// invoked, then records the result so a test can wait on it.
class DeleteSocketCallback : public TestCompletionCallbackBase {
public:
// |socket| is deleted by this object when the callback runs.
explicit DeleteSocketCallback(StreamSocket* socket)
: socket_(socket),
callback_(base::Bind(&DeleteSocketCallback::OnComplete,
base::Unretained(this))) {}
virtual ~DeleteSocketCallback() {}
// The callback to pass to the socket operation under test.
const CompletionCallback& callback() const { return callback_; }
private:
// Deletes the socket on the first invocation; a second invocation is
// reported as a test failure. The result is recorded either way.
void OnComplete(int result) {
if (socket_) {
delete socket_;
socket_ = NULL;
} else {
ADD_FAILURE() << "Deleting socket twice";
}
SetResult(result);
}
// Socket to delete; NULL once it has been deleted.
StreamSocket* socket_;
CompletionCallback callback_;
DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback);
};
// Test fixture that provides the default client socket factory plus an
// SSLClientSocketContext wired to a mock certificate verifier and a
// TransportSecurityState owned by the fixture.
class SSLClientSocketTest : public PlatformTest {
public:
SSLClientSocketTest()
: socket_factory_(ClientSocketFactory::GetDefaultFactory()),
cert_verifier_(new MockCertVerifier),
transport_security_state_(new TransportSecurityState) {
// By default the mock verifier accepts every certificate; individual tests
// override this to simulate verification failures.
cert_verifier_->set_default_result(OK);
context_.cert_verifier = cert_verifier_.get();
context_.transport_security_state = transport_security_state_.get();
}
protected:
// Wraps |transport_socket| in a ClientSocketHandle and asks the factory to
// create an SSL client socket on top of it for |host_and_port|, using
// |ssl_config| and the fixture's context.
scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<StreamSocket> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config) {
scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
connection->SetSocket(transport_socket.Pass());
return socket_factory_->CreateSSLClientSocket(
connection.Pass(), host_and_port, ssl_config, context_);
}
ClientSocketFactory* socket_factory_;
scoped_ptr<MockCertVerifier> cert_verifier_;
scoped_ptr<TransportSecurityState> transport_security_state_;
// Holds raw pointers to |cert_verifier_| and |transport_security_state_|.
SSLClientSocketContext context_;
};
// Fixture for tests that inspect the certificate request information obtained
// from an SSL client socket.
class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
protected:
// Starts an HTTPS test server configured with |ssl_options|, connects to it,
// attempts the SSL handshake, and returns the SSLCertRequestInfo filled in by
// the socket afterwards. Returns NULL if the server cannot be started or its
// address list cannot be obtained.
scoped_refptr<SSLCertRequestInfo> GetCertRequest(
SpawnedTestServer::SSLOptions ssl_options) {
SpawnedTestServer test_server(
SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
if (!test_server.Start())
return NULL;
AddressList addr;
if (!test_server.GetAddressList(&addr))
return NULL;
TestCompletionCallback callback;
CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
// The handshake result itself is not checked here; only the certificate
// request information gathered during it is of interest.
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo();
sock->GetSSLCertRequestInfo(request_info.get());
sock->Disconnect();
EXPECT_FALSE(sock->IsConnected());
return request_info;
}
};
// Returns true if entry |i| of |log| is either the end of an SSL_CONNECT event
// or a SOCKET_BYTES_SENT event (phase NONE).
static bool LogContainsSSLConnectEndEvent(
    const CapturingNetLog::CapturedEntryList& log,
    int i) {
  if (LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT))
    return true;
  if (LogContainsEvent(
          log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE)) {
    return true;
  }
  return false;
}
}
// Verifies that an SSL handshake with an HTTPS test server succeeds and that
// the SSL_CONNECT begin and end events show up in the net log.
TEST_F(SSLClientSocketTest, Connect) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
CapturingNetLog log;
// Establish the underlying TCP connection first.
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
// Layer the SSL socket on top; it must not report connected before Connect().
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
// The begin event is checked right after Connect() returns, before waiting
// for the handshake to finish.
CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
// Re-fetch the log and check the entry at index -1 for the end of the
// handshake.
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
sock->Disconnect();
EXPECT_FALSE(sock->IsConnected());
}
// Verifies that Connect() reports ERR_CERT_DATE_INVALID when the server is
// configured with CERT_EXPIRED and the mock verifier returns that error.
TEST_F(SSLClientSocketTest, ConnectExpired) {
SpawnedTestServer::SSLOptions ssl_options(
SpawnedTestServer::SSLOptions::CERT_EXPIRED);
SpawnedTestServer test_server(
SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
// Make the mock verifier reject the certificate with a date error.
cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
// Check for the SSL_CONNECT begin event before waiting for completion.
CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(ERR_CERT_DATE_INVALID, rv);
// The handshake end must still be logged even though verification failed.
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
}
// Verifies that Connect() reports ERR_CERT_COMMON_NAME_INVALID when the server
// is configured with CERT_MISMATCHED_NAME and the mock verifier returns that
// error.
TEST_F(SSLClientSocketTest, ConnectMismatched) {
SpawnedTestServer::SSLOptions ssl_options(
SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME);
SpawnedTestServer test_server(
SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
// Make the mock verifier reject the certificate with a name error.
cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID);
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
// Check for the SSL_CONNECT begin event before waiting for completion.
CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv);
// The handshake end must still be logged even though verification failed.
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
}
// Verifies that when the server requests a client certificate and the default
// SSL config is used, Connect() fails with ERR_SSL_CLIENT_AUTH_CERT_NEEDED and
// the socket is left disconnected.
TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
SpawnedTestServer::SSLOptions ssl_options;
ssl_options.request_client_certificate = true;
SpawnedTestServer test_server(
SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
// Check for the SSL_CONNECT begin event before waiting for completion.
CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
// Unlike the other connect tests, the SSL_CONNECT end event is searched for
// anywhere in the log (starting at index 0) rather than at a fixed position.
log.GetEntries(&entries);
ExpectLogContainsSomewhere(
entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END);
EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv);
EXPECT_FALSE(sock->IsConnected());
}
// Verifies that when the server requests a client certificate and the config
// says to send one but supplies a NULL certificate, the handshake succeeds and
// SSLInfo reports that no client certificate was sent.
TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
SpawnedTestServer::SSLOptions ssl_options;
ssl_options.request_client_certificate = true;
SpawnedTestServer test_server(
SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
// Ask to send a client certificate, but provide none.
SSLConfig ssl_config = kDefaultSSLConfig;
ssl_config.send_client_cert = true;
ssl_config.client_cert = NULL;
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), ssl_config));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
// Check for the SSL_CONNECT begin event before waiting for completion.
CapturingNetLog::CapturedEntryList entries;
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
log.GetEntries(&entries);
EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
// No certificate should be recorded as having been sent.
SSLInfo ssl_info;
sock->GetSSLInfo(&ssl_info);
EXPECT_FALSE(ssl_info.client_cert_sent);
sock->Disconnect();
EXPECT_FALSE(sock->IsConnected());
}
// Verifies that after a successful handshake a simple HTTP request can be
// written and the response can be read in 4096-byte chunks until a read
// returns 0, without any read error.
TEST_F(SSLClientSocketTest, Read) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
// Send the request without its terminating NUL (hence the "- 1").
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
scoped_refptr<IOBuffer> request_buffer(
new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = sock->Write(
request_buffer.get(), arraysize(request_text) - 1, callback.callback());
EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
// Drain the response; the loop ends on the first non-positive result, which
// is expected to be 0 rather than an error.
scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
for (;;) {
rv = sock->Read(buf.get(), 4096, callback.callback());
EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_GE(rv, 0);
if (rv <= 0)
break;
}
}
// Verifies what SSLClientSocket::Read() returns when the underlying transport
// fails a read synchronously with ERR_CONNECTION_RESET: the error itself in
// non-OpenSSL builds, 0 when USE_OPENSSL is defined.
TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
// Wrap the real TCP socket so that a read error can be injected later.
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
scoped_ptr<SynchronousErrorStreamSocket> transport(
new SynchronousErrorStreamSocket(real_transport.Pass()));
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
// NOTE(review): False Start is disabled, presumably so the handshake is fully
// finished before the error is injected - confirm.
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
// Keep a raw pointer; ownership of the wrapper moves into the SSL socket.
SynchronousErrorStreamSocket* raw_transport = transport.get();
scoped_ptr<SSLClientSocket> sock(
CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
test_server.host_port_pair(),
ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
static const int kRequestTextSize =
static_cast<int>(arraysize(request_text) - 1);
scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
memcpy(request_buffer->data(), request_text, kRequestTextSize);
rv = callback.GetResult(
sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
EXPECT_EQ(kRequestTextSize, rv);
// From now on every transport read fails synchronously.
raw_transport->SetNextReadError(ERR_CONNECTION_RESET);
scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback()));
#if !defined(USE_OPENSSL)
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
EXPECT_EQ(0, rv);
#endif
}
// Verifies how a transport write error surfaces when the transport write's
// completion is delayed: the SSL Write() itself reports success, and the error
// is observed by a subsequent Read() once the write is unblocked
// (ERR_CONNECTION_RESET in non-OpenSSL builds, 0 when USE_OPENSSL is defined).
TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
// Socket stack: TCP -> error injection -> blocking control -> SSL.
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
scoped_ptr<SynchronousErrorStreamSocket> error_socket(
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<FakeBlockingStreamSocket> transport(
new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
// NOTE(review): False Start is disabled, presumably so the handshake is fully
// finished before the error is injected - confirm.
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
scoped_ptr<SSLClientSocket> sock(
CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
test_server.host_port_pair(),
ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
static const int kRequestTextSize =
static_cast<int>(arraysize(request_text) - 1);
scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
memcpy(request_buffer->data(), request_text, kRequestTextSize);
// Make the transport write fail, but hold its completion back so the failure
// is not visible yet.
raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
raw_transport->SetNextWriteShouldBlock();
// The SSL-level write is expected to report the full request as written.
rv = callback.GetResult(
sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
EXPECT_EQ(kRequestTextSize, rv);
// Start a read, then release the blocked write so its error is delivered.
scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
rv = sock->Read(buf.get(), 4096, callback.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
raw_transport->UnblockWrite();
rv = callback.GetResult(rv);
#if !defined(USE_OPENSSL)
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
EXPECT_EQ(0, rv);
#endif
}
// Verifies full-duplex operation: a Write() can be issued and completed while
// a Read() is already pending on the same SSL socket, and the pending Read()
// then completes with data.
TEST_F(SSLClientSocketTest, Read_FullDuplex) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  // Used for both connects and, afterwards, for the pending Read().
  TestCompletionCallback callback;

  scoped_ptr<StreamSocket> transport(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  int rv = transport->Connect(callback.callback());
  if (rv == ERR_IO_PENDING)
    rv = callback.WaitForResult();
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
      transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  rv = sock->Connect(callback.callback());
  if (rv == ERR_IO_PENDING)
    rv = callback.WaitForResult();
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(sock->IsConnected());

  // Issue the Read() first. Nothing has been requested yet, so it must stay
  // pending while the Write() below is performed.
  scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
  rv = sock->Read(buf.get(), 4096, callback.callback());
  ASSERT_EQ(ERR_IO_PENDING, rv);

  // Build a request padded with 3770 filler characters.
  std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
  request_text.append(3770, '*');
  request_text.append("\r\n\r\n");
  const int request_size = static_cast<int>(request_text.size());
  scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text));

  // The Write() needs its own callback because |callback| is still owned by
  // the outstanding Read().
  TestCompletionCallback callback2;
  rv = sock->Write(request_buffer.get(), request_size, callback2.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
  if (rv == ERR_IO_PENDING)
    rv = callback2.WaitForResult();
  EXPECT_EQ(request_size, rv);

  // Now the pending Read() should complete with response data.
  rv = callback.WaitForResult();
  EXPECT_GT(rv, 0);
}
// Verifies that the SSL socket can be deleted from inside its Read() callback
// while a Write() is also pending: the read callback receives the transport
// error (or 0 under USE_OPENSSL) and the write callback is never run.
TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
// Socket stack: TCP -> error injection -> blocking control -> SSL.
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
scoped_ptr<SynchronousErrorStreamSocket> error_socket(
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<FakeBlockingStreamSocket> transport(
new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
// NOTE(review): False Start is disabled, presumably so the handshake is fully
// finished before errors are injected - confirm.
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
scoped_ptr<SSLClientSocket> sock =
CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
test_server.host_port_pair(),
ssl_config);
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
// Build a request with 20 KiB of filler.
std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
request_text.append(20 * 1024, '*');
request_text.append("\r\n\r\n");
scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer(
new StringIOBuffer(request_text), request_text.size()));
// Make both directions of the transport fail, and hold back the completion of
// the next transport read and write.
raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET);
raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
raw_transport->SetNextReadShouldBlock();
raw_transport->SetNextWriteShouldBlock();
// Ownership of the socket moves to |read_callback|, which deletes it when
// invoked; |raw_sock| is only used before that happens.
SSLClientSocket* raw_sock = sock.get();
DeleteSocketCallback read_callback(sock.release());
scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096));
rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback());
ASSERT_EQ(ERR_IO_PENDING, rv);
ASSERT_FALSE(read_callback.have_result());
#if !defined(USE_OPENSSL)
// In non-OpenSSL builds a first Write() is expected to be accepted only
// partially, leaving data for the second Write() below.
rv = callback.GetResult(raw_sock->Write(request_buffer.get(),
request_buffer->BytesRemaining(),
callback.callback()));
ASSERT_LT(0, rv);
request_buffer->DidConsume(rv);
ASSERT_LT(0, request_buffer->BytesRemaining());
#endif
// This Write() must remain pending.
rv = raw_sock->Write(request_buffer.get(),
request_buffer->BytesRemaining(),
callback.callback());
ASSERT_EQ(ERR_IO_PENDING, rv);
ASSERT_FALSE(callback.have_result());
// Releasing the transport write delivers its error; the read callback is
// expected to run (deleting the socket) with the result checked below.
raw_transport->UnblockWrite();
rv = read_callback.WaitForResult();
#if !defined(USE_OPENSSL)
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
EXPECT_EQ(0, rv);
#endif
// The write callback must not have been invoked after the socket was deleted.
EXPECT_FALSE(callback.have_result());
}
// Verifies that when the transport starts failing writes while a Read() is
// pending, Write() eventually reports an error and the pending Read()
// completes with ERR_CONNECTION_RESET once it is unblocked.
TEST_F(SSLClientSocketTest, Read_WithWriteError) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
// Socket stack: TCP -> error injection -> blocking control -> SSL.
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
scoped_ptr<SynchronousErrorStreamSocket> error_socket(
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<FakeBlockingStreamSocket> transport(
new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
// NOTE(review): False Start is disabled, presumably so the handshake is fully
// finished before errors are injected - confirm.
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
scoped_ptr<SSLClientSocket> sock(
CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
test_server.host_port_pair(),
ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
EXPECT_TRUE(sock->IsConnected());
// Send a small request that is expected to succeed.
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
static const int kRequestTextSize =
static_cast<int>(arraysize(request_text) - 1);
scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
memcpy(request_buffer->data(), request_text, kRequestTextSize);
rv = callback.GetResult(
sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
EXPECT_EQ(kRequestTextSize, rv);
// Start a read whose transport completion is held back, so it stays pending.
TestCompletionCallback read_callback;
raw_transport->SetNextReadShouldBlock();
scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
rv = sock->Read(buf.get(), 4096, read_callback.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
// Build a second request with 20 KiB of filler.
std::string long_request_text =
"GET / HTTP/1.1\r\nUser-Agent: long browser name ";
long_request_text.append(20 * 1024, '*');
long_request_text.append("\r\n\r\n");
scoped_refptr<DrainableIOBuffer> long_request_buffer(new DrainableIOBuffer(
new StringIOBuffer(long_request_text), long_request_text.size()));
// From now on every transport write fails. Keep writing until the SSL socket
// reports a non-positive result; the request must not be fully consumed
// before that happens.
raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
do {
rv = callback.GetResult(sock->Write(long_request_buffer.get(),
long_request_buffer->BytesRemaining(),
callback.callback()));
if (rv > 0) {
long_request_buffer->DidConsume(rv);
ASSERT_LT(0, long_request_buffer->BytesRemaining());
}
} while (rv > 0);
#if !defined(USE_OPENSSL)
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
#endif
// Release the pending read; it is expected to report the connection reset.
raw_transport->UnblockRead();
rv = read_callback.WaitForResult();
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
}
// Verifies that the response to a simple HTTP request can be read one byte at
// a time until a read returns 0, without any read error.
TEST_F(SSLClientSocketTest, Read_SmallChunks) {
SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
SpawnedTestServer::kLocalhost,
base::FilePath());
ASSERT_TRUE(test_server.Start());
AddressList addr;
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
scoped_ptr<StreamSocket> transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
// Send the request without its terminating NUL (hence the "- 1").
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
scoped_refptr<IOBuffer> request_buffer(
new IOBuffer(arraysize(request_text) - 1));
memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
rv = sock->Write(
request_buffer.get(), arraysize(request_text) - 1, callback.callback());
EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
// Drain the response using a 1-byte buffer; the loop ends on the first
// non-positive result, which is expected to be 0 rather than an error.
scoped_refptr<IOBuffer> buf(new IOBuffer(1));
for (;;) {
rv = sock->Read(buf.get(), 1, callback.callback());
EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_GE(rv, 0);
if (rv <= 0)
break;
}
}
// Requests /ssl-many-small-records through a ReadBufferingStreamSocket and
// checks that one Read() into an 8192-byte buffer is completely filled once
// the buffering transport has been limited to 15000 bytes.
// NOTE(review): the exact record layout served for this path is defined by the
// test server and is not visible here.
TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;

  // Wrap the real TCP socket so the test can control how reads are buffered.
  scoped_ptr<StreamSocket> real_transport(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  scoped_ptr<ReadBufferingStreamSocket> transport(
      new ReadBufferingStreamSocket(real_transport.Pass()));
  // Keep a raw pointer: ownership of |transport| moves into the SSL socket
  // below, but the buffer size still has to be adjusted afterwards.
  ReadBufferingStreamSocket* raw_transport = transport.get();

  int rv = callback.GetResult(transport->Connect(callback.callback()));
  ASSERT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> sock(
      CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
                            test_server.host_port_pair(),
                            kDefaultSSLConfig));
  rv = callback.GetResult(sock->Connect(callback.callback()));
  ASSERT_EQ(OK, rv);
  ASSERT_TRUE(sock->IsConnected());

  const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n";
  scoped_refptr<IOBuffer> request_buffer(
      new IOBuffer(arraysize(request_text) - 1));
  memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);

  rv = callback.GetResult(sock->Write(
      request_buffer.get(), arraysize(request_text) - 1, callback.callback()));
  ASSERT_GT(rv, 0);
  ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);

  // Only enable buffering after the handshake and request have gone through.
  raw_transport->SetBufferSize(15000);

  // Expected value goes first, matching the other assertions in this file, so
  // a failure reports "expected 8192" rather than the reverse.
  const int kReadSize = 8192;
  scoped_refptr<IOBuffer> buffer(new IOBuffer(kReadSize));
  rv = callback.GetResult(
      sock->Read(buffer.get(), kReadSize, callback.callback()));
  ASSERT_EQ(kReadSize, rv);
}
// Sends a request, performs a single Read() of at most 512 bytes, and then
// returns, so the socket is destroyed at end of scope after just that one
// read. The read itself must deliver at least one byte.
TEST_F(SSLClientSocketTest, Read_Interrupted) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  rv = callback.GetResult(ssl_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  // Send the request; the trailing NUL of the literal is not transmitted.
  const char kRequest[] = "GET / HTTP/1.0\r\n\r\n";
  const int kRequestLen = static_cast<int>(arraysize(kRequest) - 1);
  scoped_refptr<IOBuffer> request(new IOBuffer(kRequestLen));
  memcpy(request->data(), kRequest, kRequestLen);

  rv = ssl_socket->Write(request.get(), kRequestLen, callback.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
  rv = callback.GetResult(rv);
  EXPECT_EQ(kRequestLen, rv);

  // One read only; whatever else the server sends is never consumed.
  const int kReadSize = 512;
  scoped_refptr<IOBuffer> response(new IOBuffer(kReadSize));
  rv = ssl_socket->Read(response.get(), kReadSize, callback.callback());
  EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
  rv = callback.GetResult(rv);
  EXPECT_GT(rv, 0);
}
// With the NetLog capturing at LOG_ALL, checks that writing the request logs
// an SSL_SOCKET_BYTES_SENT entry and that every successful Read() is followed
// by a new SSL_SOCKET_BYTES_RECEIVED entry later in the log.
TEST_F(SSLClientSocketTest, Read_FullLogging) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  CapturingNetLog log;
  log.SetLogLevel(NetLog::LOG_ALL);

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, &log, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  rv = callback.GetResult(ssl_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(ssl_socket->IsConnected());

  // Send the request; the trailing NUL of the literal is not transmitted.
  const char kRequest[] = "GET / HTTP/1.0\r\n\r\n";
  const int kRequestLen = static_cast<int>(arraysize(kRequest) - 1);
  scoped_refptr<IOBuffer> request(new IOBuffer(kRequestLen));
  memcpy(request->data(), kRequest, kRequestLen);

  rv = ssl_socket->Write(request.get(), kRequestLen, callback.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
  rv = callback.GetResult(rv);
  EXPECT_EQ(kRequestLen, rv);

  // The write must have been logged somewhere at or after entry 5.
  CapturingNetLog::CapturedEntryList entries;
  log.GetEntries(&entries);
  size_t last_index = ExpectLogContainsSomewhereAfter(
      entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);

  const int kReadSize = 4096;
  scoped_refptr<IOBuffer> response(new IOBuffer(kReadSize));
  while (true) {
    rv = ssl_socket->Read(response.get(), kReadSize, callback.callback());
    EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
    rv = callback.GetResult(rv);
    EXPECT_GE(rv, 0);
    if (rv <= 0)
      break;

    // Each successful read has to add a BYTES_RECEIVED entry strictly after
    // the previously matched one.
    log.GetEntries(&entries);
    last_index =
        ExpectLogContainsSomewhereAfter(entries,
                                        last_index + 1,
                                        NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
                                        NetLog::PHASE_NONE);
  }
}
// Feeds the SSL handshake canned bytes whose first record header carries
// content type 0x17 (application_data in TLS) instead of a handshake record,
// and expects Connect() to fail with ERR_SSL_PROTOCOL_ERROR.
TEST_F(SSLClientSocketTest, PrematureApplicationData) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  // |addr| stays default-constructed; the transport below is a mock that is
  // driven by |data| (presumably never touching the network).
  AddressList addr;
  TestCompletionCallback callback;

  static const unsigned char application_data[] = {
      0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b,
      0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46,
      0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6,
      0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36,
      0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d,
      0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f,
      0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57,
      0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82,
      0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
      0x0a};

  // Both reads complete synchronously: first the bogus record, then OK.
  MockRead data_reads[] = {
      MockRead(SYNCHRONOUS,
               reinterpret_cast<const char*>(application_data),
               arraysize(application_data)),
      MockRead(SYNCHRONOUS, OK),
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0);

  scoped_ptr<StreamSocket> mock_transport(
      new MockTCPClientSocket(addr, NULL, &data));
  int rv = callback.GetResult(mock_transport->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      mock_transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  rv = callback.GetResult(ssl_socket->Connect(callback.callback()));
  EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
}
// Starts a server restricted to RC4 bulk ciphers while the client disables
// cipher suite 0x0005 (TLS_RSA_WITH_RC4_128_SHA in the IANA registry), and
// expects the handshake to fail with either
// ERR_SSL_VERSION_OR_CIPHER_MISMATCH or ERR_SSL_PROTOCOL_ERROR, leaving an
// SSL_HANDSHAKE_ERROR entry in the NetLog.
TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
  const uint16 kCiphersToDisable[] = {
      0x0005,
  };

  SpawnedTestServer::SSLOptions ssl_options;
  ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4;
  SpawnedTestServer test_server(
      SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  CapturingNetLog log;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, &log, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  // Append every entry of the disable list to the client configuration.
  SSLConfig ssl_config;
  ssl_config.disabled_cipher_suites.insert(
      ssl_config.disabled_cipher_suites.end(),
      kCiphersToDisable,
      kCiphersToDisable + arraysize(kCiphersToDisable));

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), ssl_config));
  EXPECT_FALSE(ssl_socket->IsConnected());

  rv = ssl_socket->Connect(callback.callback());

  // The SSL_CONNECT begin event is checked before waiting for completion.
  CapturingNetLog::CapturedEntryList entries;
  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));

  rv = callback.GetResult(rv);
  EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH ||
              rv == ERR_SSL_PROTOCOL_ERROR);

  log.GetEntries(&entries);
  ExpectLogContainsSomewhere(
      entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE);

  // The connect end event may be either the last or the second-to-last entry.
  const bool connect_end_logged = LogContainsSSLConnectEndEvent(entries, -1) ||
                                  LogContainsSSLConnectEndEvent(entries, -2);
  EXPECT_TRUE(connect_end_logged);
}
// Builds the SSL socket from a ClientSocketHandle that was populated directly
// via SetSocket() (no socket pool is involved in this test) and checks that
// the handshake still completes with OK.
TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  // Hand the connected transport to a standalone handle.
  scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle());
  handle->SetSocket(tcp_socket.Pass());

  // Call the factory directly because this overload takes the handle.
  scoped_ptr<SSLClientSocket> ssl_socket(
      socket_factory_->CreateSSLClientSocket(handle.Pass(),
                                             test_server.host_port_pair(),
                                             kDefaultSSLConfig,
                                             context_));
  EXPECT_FALSE(ssl_socket->IsConnected());

  rv = callback.GetResult(ssl_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);
}
// After a successful handshake, exports 32 bytes of keying material for two
// different labels (same empty context, context flag false) and checks that
// both exports succeed and produce different output.
TEST_F(SSLClientSocketTest, ExportKeyingMaterial) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;

  scoped_ptr<StreamSocket> transport(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  int rv = transport->Connect(callback.callback());
  if (rv == ERR_IO_PENDING)
    rv = callback.WaitForResult();
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
      transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));

  rv = sock->Connect(callback.callback());
  if (rv == ERR_IO_PENDING)
    rv = callback.WaitForResult();
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(sock->IsConnected());

  const int kKeyingMaterialSize = 32;
  const char* kKeyingLabel1 = "client-socket-test-1";
  const char* kKeyingContext = "";

  // Buffers are zeroed first so the final comparison cannot pass on
  // uninitialized bytes.
  unsigned char client_out1[kKeyingMaterialSize];
  memset(client_out1, 0, sizeof(client_out1));
  rv = sock->ExportKeyingMaterial(
      kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1));
  // Expected value first, consistent with the rest of the file, so failure
  // output labels the values correctly.
  EXPECT_EQ(OK, rv);

  const char* kKeyingLabel2 = "client-socket-test-2";
  unsigned char client_out2[kKeyingMaterialSize];
  memset(client_out2, 0, sizeof(client_out2));
  rv = sock->ExportKeyingMaterial(
      kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2));
  EXPECT_EQ(OK, rv);

  // Different labels must not yield identical keying material.
  EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0);
}
// Smoke test: calls SSLClientSocket::ClearSessionCache() with no socket or
// fixture set up. There are no assertions; the test only requires that the
// call returns.
TEST(SSLClientSocket, ClearSessionCache) {
SSLClientSocket::ClearSessionCache();
}
// Forces certificate verification to fail with ERR_CERT_INVALID and checks
// that GetUnverifiedServerCertificateChain() still returns the leaf plus
// three intermediates in the same order as redundant-server-chain.pem.
// NOTE(review): assumes a CERT_CHAIN_WRONG_ROOT server presents exactly that
// file's chain - confirm against the test server.
TEST_F(SSLClientSocketTest, VerifyServerChainProperlyOrdered) {
  // Every verification attempt fails unless a per-cert result is registered.
  cert_verifier_->set_default_result(ERR_CERT_INVALID);

  SpawnedTestServer::SSLOptions ssl_options(
      SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT);
  SpawnedTestServer test_server(
      SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  scoped_ptr<StreamSocket> transport(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  int rv = transport->Connect(callback.callback());
  rv = callback.GetResult(rv);
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
      transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  EXPECT_FALSE(sock->IsConnected());

  rv = sock->Connect(callback.callback());
  rv = callback.GetResult(rv);
  EXPECT_EQ(ERR_CERT_INVALID, rv);
  // The socket is expected to report connected despite the certificate error.
  EXPECT_TRUE(sock->IsConnected());

  // Reference chain loaded from disk.
  CertificateList server_certs =
      CreateCertificateListFromFile(GetTestCertsDirectory(),
                                    "redundant-server-chain.pem",
                                    X509Certificate::FORMAT_AUTO);
  ASSERT_EQ(4U, server_certs.size());

  // Chain as received by the client. Stop here if none is available rather
  // than dereferencing a null pointer below.
  scoped_refptr<X509Certificate> server_certificate =
      sock->GetUnverifiedServerCertificateChain();
  ASSERT_TRUE(server_certificate.get());

  const X509Certificate::OSCertHandles& server_intermediates =
      server_certificate->GetIntermediateCertificates();

  EXPECT_TRUE(X509Certificate::IsSameOSCert(
      server_certificate->os_cert_handle(), server_certs[0]->os_cert_handle()));

  ASSERT_EQ(3U, server_intermediates.size());
  EXPECT_TRUE(X509Certificate::IsSameOSCert(server_intermediates[0],
                                            server_certs[1]->os_cert_handle()));
  EXPECT_TRUE(X509Certificate::IsSameOSCert(server_intermediates[1],
                                            server_certs[2]->os_cert_handle()));
  EXPECT_TRUE(X509Certificate::IsSameOSCert(server_intermediates[2],
                                            server_certs[3]->os_cert_handle()));

  sock->Disconnect();
  EXPECT_FALSE(sock->IsConnected());
}
// Registers a mock verification result whose verified chain is the three
// certificates of redundant-validated-chain.pem, connects to a
// CERT_CHAIN_WRONG_ROOT server, and checks that SSLInfo::cert exposes that
// verified chain (leaf plus two intermediates) in the same order.
TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) {
  // Any certificate without an explicit result below fails verification.
  cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);

  CertificateList certs =
      CreateCertificateListFromFile(GetTestCertsDirectory(),
                                    "redundant-validated-chain.pem",
                                    X509Certificate::FORMAT_AUTO);
  ASSERT_EQ(3U, certs.size());

  // Build the verified chain: certs[0] as leaf, certs[1..2] as intermediates.
  X509Certificate::OSCertHandles temp_intermediates;
  temp_intermediates.push_back(certs[1]->os_cert_handle());
  temp_intermediates.push_back(certs[2]->os_cert_handle());

  CertVerifyResult verify_result;
  verify_result.verified_cert = X509Certificate::CreateFromHandle(
      certs[0]->os_cert_handle(), temp_intermediates);

  // Verification of certs[0] succeeds and reports the chain built above.
  cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK);

  // Install the chain's root as a trusted test root for this scope.
  scoped_refptr<X509Certificate> root_cert = ImportCertFromFile(
      GetTestCertsDirectory(), "redundant-validated-chain-root.pem");
  ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert);
  ScopedTestRoot scoped_root(root_cert.get());

  SpawnedTestServer::SSLOptions ssl_options(
      SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT);
  SpawnedTestServer test_server(
      SpawnedTestServer::TYPE_HTTPS,
      ssl_options,
      base::FilePath(FILE_PATH_LITERAL("net/data/ssl")));
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  CapturingNetLog log;
  scoped_ptr<StreamSocket> transport(
      new TCPClientSocket(addr, &log, NetLog::Source()));
  int rv = transport->Connect(callback.callback());
  if (rv == ERR_IO_PENDING)
    rv = callback.WaitForResult();
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
      transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  EXPECT_FALSE(sock->IsConnected());
  rv = sock->Connect(callback.callback());

  // The SSL_CONNECT begin event is checked before waiting for completion.
  CapturingNetLog::CapturedEntryList entries;
  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
  if (rv == ERR_IO_PENDING)
    rv = callback.WaitForResult();

  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(sock->IsConnected());
  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));

  SSLInfo ssl_info;
  sock->GetSSLInfo(&ssl_info);
  // Stop here if no certificate was reported rather than dereferencing a
  // null pointer below.
  ASSERT_TRUE(ssl_info.cert.get());

  // The reported chain must match the verified chain, in order.
  const X509Certificate::OSCertHandles& intermediates =
      ssl_info.cert->GetIntermediateCertificates();
  ASSERT_EQ(2U, intermediates.size());
  EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(),
                                            certs[0]->os_cert_handle()));
  EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0],
                                            certs[1]->os_cert_handle()));
  EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1],
                                            certs[2]->os_cert_handle()));

  sock->Disconnect();
  EXPECT_FALSE(sock->IsConnected());
}
// A server that requests a client certificate without configuring any
// client authorities must yield a cert request with an empty authority list.
TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) {
  SpawnedTestServer::SSLOptions options;
  options.request_client_certificate = true;

  const scoped_refptr<SSLCertRequestInfo> cert_request =
      GetCertRequest(options);
  ASSERT_TRUE(cert_request.get());

  EXPECT_EQ(0u, cert_request->cert_authorities.size());
}
// Configures the server with two client authorities (Thawte, then DigiNotar)
// and checks that the cert request lists exactly those two encoded
// distinguished names, in that order.
TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) {
  // Encoded name containing "ZA", "Thawte Consulting (Pty) Ltd." and
  // "Thawte SGC CA".
  const base::FilePath::CharType kThawteFile[] =
      FILE_PATH_LITERAL("thawte.single.pem");
  const unsigned char kThawteDN[] = {
      0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
      0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a,
      0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e,
      0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79,
      0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03,
      0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20,
      0x53, 0x47, 0x43, 0x20, 0x43, 0x41};

  // Encoded name containing "NL", "DigiNotar", "DigiNotar Root CA" and
  // "info@diginotar.nl".
  const base::FilePath::CharType kDiginotarFile[] =
      FILE_PATH_LITERAL("diginotar_root_ca.pem");
  const unsigned char kDiginotarDN[] = {
      0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
      0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a,
      0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31,
      0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69,
      0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74,
      0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48,
      0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f,
      0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e,
      0x6c};

  // Expected authority entries, built as byte strings because the encoded
  // names are binary data rather than NUL-terminated text.
  const std::string expected_thawte(reinterpret_cast<const char*>(kThawteDN),
                                    sizeof(kThawteDN));
  const std::string expected_diginotar(
      reinterpret_cast<const char*>(kDiginotarDN), sizeof(kDiginotarDN));

  SpawnedTestServer::SSLOptions options;
  options.request_client_certificate = true;
  options.client_authorities.push_back(
      GetTestClientCertsDirectory().Append(kThawteFile));
  options.client_authorities.push_back(
      GetTestClientCertsDirectory().Append(kDiginotarFile));

  const scoped_refptr<SSLCertRequestInfo> cert_request =
      GetCertRequest(options);
  ASSERT_TRUE(cert_request.get());
  ASSERT_EQ(2u, cert_request->cert_authorities.size());

  EXPECT_EQ(expected_thawte, cert_request->cert_authorities[0]);
  EXPECT_EQ(expected_diginotar, cert_request->cert_authorities[1]);
}
// Server is configured with signed_cert_timestamps_tls_ext = "test" and the
// client enables signed certificate timestamps. After a successful handshake
// the socket's signed_cert_timestamps_received_ flag must be set, except in
// USE_OPENSSL builds where it is expected to stay false.
TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension) {
  SpawnedTestServer::SSLOptions server_options;
  server_options.signed_cert_timestamps_tls_ext = "test";

  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                server_options,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  CapturingNetLog log;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, &log, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  SSLConfig client_config;
  client_config.signed_cert_timestamps_enabled = true;

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), client_config));
  EXPECT_FALSE(ssl_socket->IsConnected());

  rv = ssl_socket->Connect(callback.callback());

  // The SSL_CONNECT begin event is checked before waiting for completion.
  CapturingNetLog::CapturedEntryList entries;
  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));

  rv = callback.GetResult(rv);
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(ssl_socket->IsConnected());

  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));

#if defined(USE_OPENSSL)
  // NOTE(review): presumably the extension is not handled in OpenSSL builds.
  EXPECT_FALSE(ssl_socket->signed_cert_timestamps_received_);
#else
  EXPECT_TRUE(ssl_socket->signed_cert_timestamps_received_);
#endif

  ssl_socket->Disconnect();
  EXPECT_FALSE(ssl_socket->IsConnected());
}
// Server staples an OCSP response (CERT_AUTO certificate) and the client
// enables signed certificate timestamps. After a successful handshake the
// socket's stapled_ocsp_response_received_ flag must be set, except in
// USE_OPENSSL builds where it is expected to stay false.
TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) {
  SpawnedTestServer::SSLOptions server_options;
  server_options.staple_ocsp_response = true;
  server_options.server_certificate =
      SpawnedTestServer::SSLOptions::CERT_AUTO;

  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                server_options,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  CapturingNetLog log;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, &log, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  SSLConfig client_config;
  client_config.signed_cert_timestamps_enabled = true;

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), client_config));
  EXPECT_FALSE(ssl_socket->IsConnected());

  rv = ssl_socket->Connect(callback.callback());

  // The SSL_CONNECT begin event is checked before waiting for completion.
  CapturingNetLog::CapturedEntryList entries;
  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));

  rv = callback.GetResult(rv);
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(ssl_socket->IsConnected());

  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));

#if defined(USE_OPENSSL)
  // NOTE(review): presumably stapled responses are not surfaced in OpenSSL
  // builds.
  EXPECT_FALSE(ssl_socket->stapled_ocsp_response_received_);
#else
  EXPECT_TRUE(ssl_socket->stapled_ocsp_response_received_);
#endif

  ssl_socket->Disconnect();
  EXPECT_FALSE(ssl_socket->IsConnected());
}
// Server is configured with signed_cert_timestamps_tls_ext = "test" but the
// client leaves signed certificate timestamps disabled. The handshake must
// still succeed and signed_cert_timestamps_received_ must stay false.
TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) {
  SpawnedTestServer::SSLOptions server_options;
  server_options.signed_cert_timestamps_tls_ext = "test";

  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                server_options,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  CapturingNetLog log;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, &log, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  SSLConfig client_config;
  client_config.signed_cert_timestamps_enabled = false;

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), client_config));
  EXPECT_FALSE(ssl_socket->IsConnected());

  rv = ssl_socket->Connect(callback.callback());

  // The SSL_CONNECT begin event is checked before waiting for completion.
  CapturingNetLog::CapturedEntryList entries;
  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));

  rv = callback.GetResult(rv);
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(ssl_socket->IsConnected());

  log.GetEntries(&entries);
  EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));

  EXPECT_FALSE(ssl_socket->signed_cert_timestamps_received_);

  ssl_socket->Disconnect();
  EXPECT_FALSE(ssl_socket->IsConnected());
}
// Checks the reuse-related state of an SSL socket: right after the handshake
// it is connected, idle and reports WasEverUsed() == false; once a request
// has been written WasEverUsed() must become true.
TEST_F(SSLClientSocketTest, ReuseStates) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;

  scoped_ptr<StreamSocket> tcp_socket(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  int rv = callback.GetResult(tcp_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  scoped_ptr<SSLClientSocket> ssl_socket(CreateSSLClientSocket(
      tcp_socket.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
  rv = callback.GetResult(ssl_socket->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  // State immediately after the handshake, before any application data.
  EXPECT_TRUE(ssl_socket->IsConnected());
  EXPECT_TRUE(ssl_socket->IsConnectedAndIdle());
  EXPECT_FALSE(ssl_socket->WasEverUsed());

  // Send the request; the trailing NUL of the literal is not transmitted.
  const char kRequestText[] = "GET / HTTP/1.0\r\n\r\n";
  const size_t kRequestLen = arraysize(kRequestText) - 1;
  scoped_refptr<IOBuffer> request(new IOBuffer(kRequestLen));
  memcpy(request->data(), kRequestText, kRequestLen);

  rv = ssl_socket->Write(request.get(), kRequestLen, callback.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
  rv = callback.GetResult(rv);
  EXPECT_EQ(static_cast<int>(kRequestLen), rv);

  // Writing application data marks the socket as used.
  EXPECT_TRUE(ssl_socket->WasEverUsed());
}
}