root/google_apis/gcm/engine/connection_handler_impl.cc

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. ConnectionHandlerImpl (constructor)
  2. Init
  3. Reset
  4. CanSendMessage
  5. SendMessage
  6. Login
  7. OnMessageSent
  8. GetNextMessage
  9. WaitForData
  10. OnGotVersion
  11. OnGotMessageTag
  12. OnGotMessageSize
  13. OnGotMessageBytes
  14. OnTimeout
  15. CloseConnection

// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "google_apis/gcm/engine/connection_handler_impl.h"

#include "base/message_loop/message_loop.h"
#include "google/protobuf/io/coded_stream.h"
#include "google_apis/gcm/base/mcs_util.h"
#include "google_apis/gcm/base/socket_stream.h"
#include "google_apis/gcm/protocol/mcs.pb.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"

using namespace google::protobuf::io;

namespace gcm {

namespace {

// # of bytes a MCS version packet consumes.
const int kVersionPacketLen = 1;
// # of bytes a tag packet consumes.
const int kTagPacketLen = 1;
// Min and max # of bytes a length packet consumes. The length is a protobuf
// varint32, which carries 7 payload bits per byte, so with a 2-byte maximum
// only message sizes below 2^14 (16384) bytes can have their size decoded by
// this handler.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// The current MCS protocol version.
const int kMCSVersion = 41;

}  // namespace

// Records the read timeout and the three client callbacks. No socket is
// attached (socket_ starts as NULL) and no I/O happens until Init() is called.
// NOTE(review): the initializer order presumably mirrors the member
// declaration order in the header; keep them in sync.
ConnectionHandlerImpl::ConnectionHandlerImpl(
    base::TimeDelta read_timeout,
    const ProtoReceivedCallback& read_callback,
    const ProtoSentCallback& write_callback,
    const ConnectionChangedCallback& connection_callback)
    : read_timeout_(read_timeout),
      socket_(NULL),
      handshake_complete_(false),
      message_tag_(0),
      message_size_(0),
      read_callback_(read_callback),
      write_callback_(write_callback),
      connection_callback_(connection_callback),
      weak_ptr_factory_(this) {
}

// Nothing is torn down explicitly here: the streams are held in smart
// pointers, and socket_ is a raw pointer this class never deletes. Note that
// the socket is not Disconnect()ed on destruction (only CloseConnection()
// does that).
ConnectionHandlerImpl::~ConnectionHandlerImpl() {}

// (Re)binds the handler to |socket| and immediately begins the MCS handshake
// by sending |login_request|. Callbacks still bound to a previous connection
// are cancelled by invalidating all outstanding weak pointers first.
void ConnectionHandlerImpl::Init(
    const mcs_proto::LoginRequest& login_request,
    net::StreamSocket* socket) {
  // Every client callback must have been supplied at construction.
  DCHECK(!read_callback_.is_null());
  DCHECK(!write_callback_.is_null());
  DCHECK(!connection_callback_.is_null());

  // Invalidate any previously outstanding reads.
  weak_ptr_factory_.InvalidateWeakPtrs();

  // Start from a clean per-connection parse state.
  message_size_ = 0;
  message_tag_ = 0;
  handshake_complete_ = false;

  // Wrap the new socket in fresh input and output streams.
  socket_ = socket;
  input_stream_.reset(new SocketInputStream(socket));
  output_stream_.reset(new SocketOutputStream(socket));

  Login(login_request);
}

// Public entry point for tearing down the current connection; simply
// delegates to CloseConnection(). No connection callback is run.
void ConnectionHandlerImpl::Reset() {
  CloseConnection();
}

// True only after the login handshake has completed and while the output
// stream exists and is in the EMPTY state (presumably: no earlier write is
// still in flight). SendMessage() DCHECKs the same conditions.
bool ConnectionHandlerImpl::CanSendMessage() const {
  if (!handshake_complete_)
    return false;
  if (!output_stream_.get())
    return false;
  return output_stream_->GetState() == SocketOutputStream::EMPTY;
}

// Frames |message| as <1-byte tag><varint32 size><serialized payload> and
// flushes it to the socket. Requires a completed handshake and an EMPTY
// output stream (see CanSendMessage()). Completion is handled by
// OnMessageSent(), which runs write_callback_ on success or closes the
// connection and runs connection_callback_ on failure.
void ConnectionHandlerImpl::SendMessage(
    const google::protobuf::MessageLite& message) {
  DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
  DCHECK(handshake_complete_);

  {
    CodedOutputStream coded_output_stream(output_stream_.get());
    DVLOG(1) << "Writing proto of size " << message.ByteSize();
    int tag = GetMCSProtoTag(message);
    DCHECK_NE(tag, -1);
    // Emit the tag as an explicit single byte. Writing the first byte of the
    // int's in-memory representation (WriteRaw(&tag, 1)) only yields the tag
    // on little-endian hosts.
    const uint8 tag_byte = static_cast<uint8>(tag);
    coded_output_stream.WriteRaw(&tag_byte, 1);
    coded_output_stream.WriteVarint32(message.ByteSize());
    message.SerializeToCodedStream(&coded_output_stream);
  }

  // Per net:: convention the bound callback only fires when ERR_IO_PENDING is
  // returned; any other result means the flush already finished (or failed),
  // so handle it here.
  if (output_stream_->Flush(
          base::Bind(&ConnectionHandlerImpl::OnMessageSent,
                     weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
    OnMessageSent();
  }
}

// Begins the MCS handshake. Writes the protocol version byte followed by the
// LoginRequest tag, its varint32 size and its serialized bytes, then flushes.
// Afterwards it arms the read timeout and waits for the server's version
// byte plus the LoginResponse tag and size.
void ConnectionHandlerImpl::Login(
    const google::protobuf::MessageLite& login_request) {
  DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);

  const char version_byte[1] = {kMCSVersion};
  const char login_request_tag[1] = {kLoginRequestTag};
  {
    // Scoped so the CodedOutputStream is destroyed before Flush() below.
    CodedOutputStream coded_out(output_stream_.get());
    coded_out.WriteRaw(version_byte, 1);
    coded_out.WriteRaw(login_request_tag, 1);
    coded_out.WriteVarint32(login_request.ByteSize());
    login_request.SerializeToCodedStream(&coded_out);
  }

  const int flush_result =
      output_stream_->Flush(base::Bind(&ConnectionHandlerImpl::OnMessageSent,
                                       weak_ptr_factory_.GetWeakPtr()));
  if (flush_result != net::ERR_IO_PENDING) {
    // The flush did not go asynchronous, so its callback won't run; report
    // completion from a posted task instead of synchronously.
    base::MessageLoop::current()->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionHandlerImpl::OnMessageSent,
                   weak_ptr_factory_.GetWeakPtr()));
  }

  // The server must answer within |read_timeout_|.
  read_timeout_timer_.Start(
      FROM_HERE,
      read_timeout_,
      base::Bind(&ConnectionHandlerImpl::OnTimeout,
                 weak_ptr_factory_.GetWeakPtr()));
  WaitForData(MCS_VERSION_TAG_AND_SIZE);
}

// Completion handler for an output-stream Flush(). If the stream drained
// back to EMPTY, the write succeeded and write_callback_ runs. Otherwise the
// connection is closed and the stream's error (or ERR_FAILED if it recorded
// none) is reported via connection_callback_.
void ConnectionHandlerImpl::OnMessageSent() {
  if (!output_stream_.get()) {
    // The connection has already been closed. Just return.
    DCHECK(!input_stream_.get());
    DCHECK(!read_timeout_timer_.IsRunning());
    return;
  }

  if (output_stream_->GetState() == SocketOutputStream::EMPTY) {
    write_callback_.Run();
    return;
  }

  // Capture the error before CloseConnection() destroys the stream.
  const int stream_error = output_stream_->last_error();
  CloseConnection();
  // If the socket stream had an error, plumb it up, else plumb up FAILED.
  connection_callback_.Run(stream_error != net::OK ? stream_error
                                                   : net::ERR_FAILED);
}

// Clears the state of the previously handled message and starts waiting for
// the next message's tag and size bytes.
void ConnectionHandlerImpl::GetNextMessage() {
  // The input stream must be idle here: EMPTY or READY.
  DCHECK(SocketInputStream::READY == input_stream_->GetState() ||
         SocketInputStream::EMPTY == input_stream_->GetState());

  message_size_ = 0;
  message_tag_ = 0;

  WaitForData(MCS_TAG_AND_SIZE);
}

// Drives the read state machine. Ensures input_stream_ holds at least the
// minimum number of bytes that |state| needs, issuing a socket read capped at
// the state's maximum when necessary. It then dispatches to the matching
// OnGot*() handler. If the read completes asynchronously, this method is
// re-run with the same |state|. Input-stream errors close the connection and
// are reported through connection_callback_.
void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
  DVLOG(1) << "Waiting for MCS data: state == " << state;

  if (!input_stream_) {
    // The connection has already been closed. Just return.
    DCHECK(!output_stream_.get());
    DCHECK(!read_timeout_timer_.IsRunning());
    return;
  }

  if (input_stream_->GetState() != SocketInputStream::EMPTY &&
      input_stream_->GetState() != SocketInputStream::READY) {
    // An error occurred on the input stream. Take the input stream's error
    // (not the output stream's) before CloseConnection() destroys it.
    int last_error = input_stream_->last_error();
    CloseConnection();
    // If the socket stream had an error, plumb it up, else plumb up FAILED.
    if (last_error == net::OK)
      last_error = net::ERR_FAILED;
    connection_callback_.Run(last_error);
    return;
  }

  // Used to determine whether a Socket::Read is necessary.
  int min_bytes_needed = 0;
  // Used to limit the size of the Socket::Read.
  int max_bytes_needed = 0;

  switch(state) {
    case MCS_VERSION_TAG_AND_SIZE:
      min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
      max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
      break;
    case MCS_TAG_AND_SIZE:
      min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
      max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
      break;
    case MCS_FULL_SIZE:
      // If in this state, the minimum size packet length must already have been
      // insufficient, so set both to the max length.
      min_bytes_needed = kSizePacketLenMax;
      max_bytes_needed = kSizePacketLenMax;
      break;
    case MCS_PROTO_BYTES:
      read_timeout_timer_.Reset();
      // No variability in the message size, set both to the same.
      min_bytes_needed = message_size_;
      max_bytes_needed = message_size_;
      break;
    default:
      NOTREACHED();
  }
  DCHECK_GE(max_bytes_needed, min_bytes_needed);

  int byte_count = input_stream_->UnreadByteCount();
  if (min_bytes_needed - byte_count > 0 &&
      input_stream_->Refresh(
          base::Bind(&ConnectionHandlerImpl::WaitForData,
                     weak_ptr_factory_.GetWeakPtr(),
                     state),
          max_bytes_needed - byte_count) == net::ERR_IO_PENDING) {
    return;
  }

  // Check for refresh errors.
  if (input_stream_->GetState() != SocketInputStream::READY) {
    // An error occurred on the input stream; report its own error.
    int last_error = input_stream_->last_error();
    CloseConnection();
    // If the socket stream had an error, plumb it up, else plumb up FAILED.
    if (last_error == net::OK)
      last_error = net::ERR_FAILED;
    connection_callback_.Run(last_error);
    return;
  }

  // A synchronous read may deliver fewer bytes than requested. Processing a
  // partial size field or payload would fail the parse or desynchronize the
  // stream, so wait for the remainder instead. The retry is posted rather
  // than run recursively so that a run of short reads cannot grow the stack.
  if (input_stream_->UnreadByteCount() < min_bytes_needed) {
    DVLOG(1) << "Socket read finished prematurely. Waiting for "
             << min_bytes_needed - input_stream_->UnreadByteCount()
             << " more bytes.";
    base::MessageLoop::current()->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionHandlerImpl::WaitForData,
                   weak_ptr_factory_.GetWeakPtr(),
                   state));
    return;
  }

  // Received enough bytes, process them.
  DVLOG(1) << "Processing MCS data: state == " << state;
  switch(state) {
    case MCS_VERSION_TAG_AND_SIZE:
      OnGotVersion();
      break;
    case MCS_TAG_AND_SIZE:
      OnGotMessageTag();
      break;
    case MCS_FULL_SIZE:
      OnGotMessageSize();
      break;
    case MCS_PROTO_BYTES:
      OnGotMessageBytes();
      break;
    default:
      NOTREACHED();
  }
}

void ConnectionHandlerImpl::OnGotVersion() {
  uint8 version = 0;
  {
    CodedInputStream coded_input_stream(input_stream_.get());
    coded_input_stream.ReadRaw(&version, 1);
  }
  // TODO(zea): remove this when the server is ready.
  if (version < kMCSVersion && version != 38) {
    LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
    connection_callback_.Run(net::ERR_FAILED);
    return;
  }

  input_stream_->RebuildBuffer();

  // Process the LoginResponse message tag.
  OnGotMessageTag();
}

void ConnectionHandlerImpl::OnGotMessageTag() {
  if (input_stream_->GetState() != SocketInputStream::READY) {
    LOG(ERROR) << "Failed to receive protobuf tag.";
    read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
    return;
  }

  {
    CodedInputStream coded_input_stream(input_stream_.get());
    coded_input_stream.ReadRaw(&message_tag_, 1);
  }

  DVLOG(1) << "Received proto of type "
           << static_cast<unsigned int>(message_tag_);

  if (!read_timeout_timer_.IsRunning()) {
    read_timeout_timer_.Start(FROM_HERE,
                              read_timeout_,
                              base::Bind(&ConnectionHandlerImpl::OnTimeout,
                                         weak_ptr_factory_.GetWeakPtr()));
  }
  OnGotMessageSize();
}

// Decodes the varint32 message size that follows the tag.
//
// If the buffered bytes end in the middle of the varint, the partial byte is
// un-read and the method waits (MCS_FULL_SIZE) for kSizePacketLenMax bytes.
// If even that many bytes do not decode, an empty (null) message is handed to
// read_callback_. Otherwise it either waits for the payload or, for a
// zero-length message, handles it immediately.
void ConnectionHandlerImpl::OnGotMessageSize() {
  if (input_stream_->GetState() != SocketInputStream::READY) {
    LOG(ERROR) << "Failed to receive message size.";
    read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
    return;
  }

  bool need_another_byte = false;
  // Measure with the number of *unread* bytes, as WaitForData() does.
  // ByteCount() is the running total of consumed bytes: it grows while the
  // varint is read, which made the back-up amount below negative. It also
  // includes the already-consumed tag (and version) bytes, which skewed the
  // size-limit check.
  int prev_unread_count = input_stream_->UnreadByteCount();
  {
    CodedInputStream coded_input_stream(input_stream_.get());
    if (!coded_input_stream.ReadVarint32(&message_size_))
      need_another_byte = true;
  }

  if (need_another_byte) {
    DVLOG(1) << "Expecting another message size byte.";
    if (prev_unread_count >= kSizePacketLenMax) {
      // Already had enough bytes, something else went wrong.
      LOG(ERROR) << "Failed to process message size.";
      read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
      return;
    }
    // Back up by the amount read (should always be 1 byte).
    int bytes_read = prev_unread_count - input_stream_->UnreadByteCount();
    DCHECK_EQ(bytes_read, 1);
    input_stream_->BackUp(bytes_read);
    WaitForData(MCS_FULL_SIZE);
    return;
  }

  DVLOG(1) << "Proto size: " << message_size_;

  if (message_size_ > 0)
    WaitForData(MCS_PROTO_BYTES);
  else
    OnGotMessageBytes();
}

void ConnectionHandlerImpl::OnGotMessageBytes() {
  read_timeout_timer_.Stop();
  scoped_ptr<google::protobuf::MessageLite> protobuf(
      BuildProtobufFromTag(message_tag_));
  // Messages with no content are valid; just use the default protobuf for
  // that tag.
  if (protobuf.get() && message_size_ == 0) {
    base::MessageLoop::current()->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionHandlerImpl::GetNextMessage,
                   weak_ptr_factory_.GetWeakPtr()));
    read_callback_.Run(protobuf.Pass());
    return;
  }

  if (!protobuf.get() ||
      input_stream_->GetState() != SocketInputStream::READY) {
    LOG(ERROR) << "Failed to extract protobuf bytes of type "
               << static_cast<unsigned int>(message_tag_);
    // Reset the connection.
    connection_callback_.Run(net::ERR_FAILED);
    return;
  }

  {
    CodedInputStream coded_input_stream(input_stream_.get());
    if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
      LOG(ERROR) << "Unable to parse GCM message of type "
                 << static_cast<unsigned int>(message_tag_);
      // Reset the connection.
      connection_callback_.Run(net::ERR_FAILED);
      return;
    }
  }

  input_stream_->RebuildBuffer();
  base::MessageLoop::current()->PostTask(
      FROM_HERE,
      base::Bind(&ConnectionHandlerImpl::GetNextMessage,
                 weak_ptr_factory_.GetWeakPtr()));
  if (message_tag_ == kLoginResponseTag) {
    if (handshake_complete_) {
      LOG(ERROR) << "Unexpected login response.";
    } else {
      handshake_complete_ = true;
      DVLOG(1) << "GCM Handshake complete.";
      connection_callback_.Run(net::OK);
    }
  }
  read_callback_.Run(protobuf.Pass());
}

// Runs when read_timeout_timer_ expires before the expected data (handshake
// response or message payload) arrived. Drops the connection and reports
// net::ERR_TIMED_OUT to connection_callback_.
void ConnectionHandlerImpl::OnTimeout() {
  LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
  CloseConnection();  // Also invalidates any pending read callbacks.
  connection_callback_.Run(net::ERR_TIMED_OUT);
}

// Tears down the current connection:
//  - stops the read timeout;
//  - disconnects and forgets the socket;
//  - clears per-message and handshake state;
//  - destroys both streams;
//  - invalidates weak pointers so callbacks already bound to this connection
//    become no-ops.
// socket_ is NULL-checked, so calling this on an already-closed handler is
// harmless. No client callback is run from here.
void ConnectionHandlerImpl::CloseConnection() {
  DVLOG(1) << "Closing connection.";
  read_timeout_timer_.Stop();

  if (socket_ != NULL)
    socket_->Disconnect();
  socket_ = NULL;

  // Forget all protocol state tied to the old connection.
  message_size_ = 0;
  message_tag_ = 0;
  handshake_complete_ = false;

  // Drop the streams, then cancel anything still bound to them.
  input_stream_.reset();
  output_stream_.reset();
  weak_ptr_factory_.InvalidateWeakPtrs();
}

}  // namespace gcm

/* [<][>][^][v][top][bottom][index][help] */