root/tools/android/forwarder2/socket.cc

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes following definitions.
  1. FamilyIsTCP
  2. BindUnix
  3. BindTcp
  4. ConnectUnix
  5. ConnectTcp
  6. addr_len_
  7. Shutdown
  8. Close
  9. InitSocketInternal
  10. SetNonBlocking
  11. InitUnixSocket
  12. InitTcpSocket
  13. BindAndListen
  14. Accept
  15. Connect
  16. Resolve
  17. GetPort
  18. ReadNumBytes
  19. SetSocketError
  20. Read
  21. NonBlockingRead
  22. Write
  23. NonBlockingWrite
  24. WriteString
  25. AddEventFd
  26. DidReceiveEventOnFd
  27. DidReceiveEvent
  28. WriteNumBytes
  29. WaitForEvent
  30. GetUnixDomainSocketProcessOwner

// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "tools/android/forwarder2/socket.h"

#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include "base/logging.h"
#include "base/posix/eintr_wrapper.h"
#include "base/safe_strerror_posix.h"
#include "tools/android/common/net.h"
#include "tools/android/forwarder2/common.h"

namespace {
const int kNoTimeout = -1;
const int kConnectTimeOut = 10;  // Seconds.

bool FamilyIsTCP(int family) {
  return family == AF_INET || family == AF_INET6;
}
}  // namespace

namespace forwarder2 {

bool Socket::BindUnix(const std::string& path) {
  errno = 0;
  if (!InitUnixSocket(path) || !BindAndListen()) {
    Close();
    return false;
  }
  return true;
}

bool Socket::BindTcp(const std::string& host, int port) {
  errno = 0;
  if (!InitTcpSocket(host, port) || !BindAndListen()) {
    Close();
    return false;
  }
  return true;
}

bool Socket::ConnectUnix(const std::string& path) {
  errno = 0;
  if (!InitUnixSocket(path) || !Connect()) {
    Close();
    return false;
  }
  return true;
}

bool Socket::ConnectTcp(const std::string& host, int port) {
  errno = 0;
  if (!InitTcpSocket(host, port) || !Connect()) {
    Close();
    return false;
  }
  return true;
}

Socket::Socket()
    : socket_(-1),
      port_(0),
      socket_error_(false),
      family_(AF_INET),
      addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
      addr_len_(sizeof(sockaddr)) {
  memset(&addr_, 0, sizeof(addr_));
}

Socket::~Socket() {
  Close();
}

void Socket::Shutdown() {
  if (!IsClosed()) {
    PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
  }
}

void Socket::Close() {
  if (!IsClosed()) {
    CloseFD(socket_);
    socket_ = -1;
  }
}

bool Socket::InitSocketInternal() {
  socket_ = socket(family_, SOCK_STREAM, 0);
  if (socket_ < 0) {
    PLOG(ERROR) << "socket";
    return false;
  }
  tools::DisableNagle(socket_);
  int reuse_addr = 1;
  setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
             sizeof(reuse_addr));
  if (!SetNonBlocking())
    return false;
  return true;
}

bool Socket::SetNonBlocking() {
  const int flags = fcntl(socket_, F_GETFL);
  if (flags < 0) {
    PLOG(ERROR) << "fcntl";
    return false;
  }
  if (flags & O_NONBLOCK)
    return true;
  if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
    PLOG(ERROR) << "fcntl";
    return false;
  }
  return true;
}

bool Socket::InitUnixSocket(const std::string& path) {
  static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
  // For abstract sockets we need one extra byte for the leading zero.
  if (path.size() + 2 /* '\0' */ > kPathMax) {
    LOG(ERROR) << "The provided path is too big to create a unix "
               << "domain socket: " << path;
    return false;
  }
  family_ = PF_UNIX;
  addr_.addr_un.sun_family = family_;
  // Copied from net/socket/unix_domain_socket_posix.cc
  // Convert the path given into abstract socket name. It must start with
  // the '\0' character, so we are adding it. |addr_len| must specify the
  // length of the structure exactly, as potentially the socket name may
  // have '\0' characters embedded (although we don't support this).
  // Note that addr_.addr_un.sun_path is already zero initialized.
  memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
  addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
  addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
  return InitSocketInternal();
}

bool Socket::InitTcpSocket(const std::string& host, int port) {
  port_ = port;
  if (host.empty()) {
    // Use localhost: INADDR_LOOPBACK
    family_ = AF_INET;
    addr_.addr4.sin_family = family_;
    addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  } else if (!Resolve(host)) {
    return false;
  }
  CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
  if (family_ == AF_INET) {
    addr_.addr4.sin_port = htons(port_);
    addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
    addr_len_ = sizeof(addr_.addr4);
  } else if (family_ == AF_INET6) {
    addr_.addr6.sin6_port = htons(port_);
    addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
    addr_len_ = sizeof(addr_.addr6);
  }
  return InitSocketInternal();
}

bool Socket::BindAndListen() {
  errno = 0;
  if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
      HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
    PLOG(ERROR) << "bind/listen";
    SetSocketError();
    return false;
  }
  if (port_ == 0 && FamilyIsTCP(family_)) {
    SockAddr addr;
    memset(&addr, 0, sizeof(addr));
    socklen_t addrlen = 0;
    sockaddr* addr_ptr = NULL;
    uint16* port_ptr = NULL;
    if (family_ == AF_INET) {
      addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
      port_ptr = &addr.addr4.sin_port;
      addrlen = sizeof(addr.addr4);
    } else if (family_ == AF_INET6) {
      addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
      port_ptr = &addr.addr6.sin6_port;
      addrlen = sizeof(addr.addr6);
    }
    errno = 0;
    if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
      PLOG(ERROR) << "getsockname";
      SetSocketError();
      return false;
    }
    port_ = ntohs(*port_ptr);
  }
  return true;
}

bool Socket::Accept(Socket* new_socket) {
  DCHECK(new_socket != NULL);
  if (!WaitForEvent(READ, kNoTimeout)) {
    SetSocketError();
    return false;
  }
  errno = 0;
  int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
  if (new_socket_fd < 0) {
    SetSocketError();
    return false;
  }
  tools::DisableNagle(new_socket_fd);
  new_socket->socket_ = new_socket_fd;
  if (!new_socket->SetNonBlocking())
    return false;
  return true;
}

bool Socket::Connect() {
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  errno = 0;
  if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
      errno != EINPROGRESS) {
    SetSocketError();
    return false;
  }
  // Wait for connection to complete, or receive a notification.
  if (!WaitForEvent(WRITE, kConnectTimeOut)) {
    SetSocketError();
    return false;
  }
  int socket_errno;
  socklen_t opt_len = sizeof(socket_errno);
  if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
    PLOG(ERROR) << "getsockopt()";
    SetSocketError();
    return false;
  }
  if (socket_errno != 0) {
    LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
    SetSocketError();
    return false;
  }
  return true;
}

bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  struct addrinfo* res;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    errno = 0;
    SetSocketError();
    freeaddrinfo(res);
    return false;
  }
  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4,
             reinterpret_cast<sockaddr_in*>(res->ai_addr),
             sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6,
             reinterpret_cast<sockaddr_in6*>(res->ai_addr),
             sizeof(sockaddr_in6));
      break;
  }
  freeaddrinfo(res);
  return true;
}

int Socket::GetPort() {
  if (!FamilyIsTCP(family_)) {
    LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
    return 0;
  }
  return port_;
}

int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
  int bytes_read = 0;
  int ret = 1;
  while (bytes_read < num_bytes && ret > 0) {
    ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
    if (ret >= 0)
      bytes_read += ret;
  }
  return bytes_read;
}

void Socket::SetSocketError() {
  socket_error_ = true;
  DCHECK_NE(EAGAIN, errno);
  DCHECK_NE(EWOULDBLOCK, errno);
  Close();
}

int Socket::Read(void* buffer, size_t buffer_size) {
  if (!WaitForEvent(READ, kNoTimeout)) {
    SetSocketError();
    return 0;
  }
  int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
  if (ret < 0) {
    PLOG(ERROR) << "read";
    SetSocketError();
  }
  return ret;
}

int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
  if (ret < 0) {
    PLOG(ERROR) << "read";
    SetSocketError();
  }
  return ret;
}

int Socket::Write(const void* buffer, size_t count) {
  if (!WaitForEvent(WRITE, kNoTimeout)) {
    SetSocketError();
    return 0;
  }
  int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
  if (ret < 0) {
    PLOG(ERROR) << "send";
    SetSocketError();
  }
  return ret;
}

int Socket::NonBlockingWrite(const void* buffer, size_t count) {
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
  if (ret < 0) {
    PLOG(ERROR) << "send";
    SetSocketError();
  }
  return ret;
}

int Socket::WriteString(const std::string& buffer) {
  return WriteNumBytes(buffer.c_str(), buffer.size());
}

void Socket::AddEventFd(int event_fd) {
  Event event;
  event.fd = event_fd;
  event.was_fired = false;
  events_.push_back(event);
}

bool Socket::DidReceiveEventOnFd(int fd) const {
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].fd == fd)
      return events_[i].was_fired;
  return false;
}

bool Socket::DidReceiveEvent() const {
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].was_fired)
      return true;
  return false;
}

int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
  int bytes_written = 0;
  int ret = 1;
  while (bytes_written < num_bytes && ret > 0) {
    ret = Write(static_cast<const char*>(buffer) + bytes_written,
                num_bytes - bytes_written);
    if (ret >= 0)
      bytes_written += ret;
  }
  return bytes_written;
}

bool Socket::WaitForEvent(EventType type, int timeout_secs) {
  if (socket_ == -1)
    return true;
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  fd_set read_fds;
  fd_set write_fds;
  FD_ZERO(&read_fds);
  FD_ZERO(&write_fds);
  if (type == READ)
    FD_SET(socket_, &read_fds);
  else
    FD_SET(socket_, &write_fds);
  for (size_t i = 0; i < events_.size(); ++i)
    FD_SET(events_[i].fd, &read_fds);
  timeval tv = {};
  timeval* tv_ptr = NULL;
  if (timeout_secs > 0) {
    tv.tv_sec = timeout_secs;
    tv.tv_usec = 0;
    tv_ptr = &tv;
  }
  int max_fd = socket_;
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].fd > max_fd)
      max_fd = events_[i].fd;
  if (HANDLE_EINTR(
          select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
    PLOG(ERROR) << "select";
    return false;
  }
  bool event_was_fired = false;
  for (size_t i = 0; i < events_.size(); ++i) {
    if (FD_ISSET(events_[i].fd, &read_fds)) {
      events_[i].was_fired = true;
      event_was_fired = true;
    }
  }
  return !event_was_fired;
}

// static
pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
  Socket socket;
  if (!socket.ConnectUnix(path))
    return -1;
  ucred ucred;
  socklen_t len = sizeof(ucred);
  if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
    CHECK_NE(ENOPROTOOPT, errno);
    return -1;
  }
  return ucred.pid;
}

}  // namespace forwarder2

/* [<][>][^][v][top][bottom][index][help] */