root/chrome/utility/local_discovery/service_discovery_message_handler.cc

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. CreateSocket
  2. CreateSockets
  3. AddSocket
  4. Reset
  5. ClosePlatformSocket
  6. StaticInitializeSocketFactory
  7. ClosePlatformSocket
  8. StaticInitializeSocketFactory
  9. SendHostMessageOnUtilityThread
  10. WatcherUpdateToString
  11. ResolverStatusToString
  12. PreSandboxStartup
  13. InitializeMdns
  14. InitializeThread
  15. OnMessageReceived
  16. PostTask
  17. OnSetSockets
  18. OnStartWatcher
  19. OnDiscoverServices
  20. OnSetActivelyRefreshServices
  21. OnDestroyWatcher
  22. OnResolveService
  23. OnDestroyResolver
  24. OnResolveLocalDomain
  25. OnDestroyLocalDomainResolver
  26. StartWatcher
  27. DiscoverServices
  28. SetActivelyRefreshServices
  29. DestroyWatcher
  30. ResolveService
  31. DestroyResolver
  32. ResolveLocalDomain
  33. DestroyLocalDomainResolver
  34. ShutdownLocalDiscovery
  35. ShutdownOnIOThread
  36. OnServiceUpdated
  37. OnServiceResolved
  38. OnLocalDomainResolved
  39. Send

// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/utility/local_discovery/service_discovery_message_handler.h"

#include <algorithm>

#include "base/lazy_instance.h"
#include "chrome/common/local_discovery/local_discovery_messages.h"
#include "chrome/utility/local_discovery/service_discovery_client_impl.h"
#include "content/public/utility/utility_thread.h"
#include "net/socket/socket_descriptor.h"
#include "net/udp/datagram_server_socket.h"

namespace local_discovery {

namespace {

// Closes |socket| with the platform's native call. Declared ahead of its
// platform-specific definitions further down because the destructor of
// ScopedSocketFactory calls it.
void ClosePlatformSocket(net::SocketDescriptor socket);

// Sets the socket factory used by |net::CreatePlatformSocket|. The
// implementation keeps a single socket that is returned to the first call to
// |net::CreatePlatformSocket| during the object's lifetime; any later call
// receives |net::kInvalidSocket|.
class ScopedSocketFactory : public net::PlatformSocketFactory {
 public:
  // Stores |socket| and registers this object as the current platform socket
  // factory.
  explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) {
    net::PlatformSocketFactory::SetInstance(this);
  }

  // Unregisters this factory and closes the socket if it was never handed out.
  virtual ~ScopedSocketFactory() {
    net::PlatformSocketFactory::SetInstance(NULL);
    // After CreateSocket() has given the descriptor away, |socket_| holds
    // |net::kInvalidSocket|; skip the close in that case so the platform call
    // is never issued on an invalid descriptor.
    if (socket_ != net::kInvalidSocket) {
      ClosePlatformSocket(socket_);
      socket_ = net::kInvalidSocket;
    }
  }

  // net::PlatformSocketFactory implementation. Returns the stored descriptor
  // and leaves |net::kInvalidSocket| behind, so this object no longer closes
  // it. Only SOCK_DGRAM requests for AF_INET/AF_INET6 are expected (DCHECKed);
  // |protocol| is not inspected.
  virtual net::SocketDescriptor CreateSocket(int family, int type,
                                             int protocol) OVERRIDE {
    DCHECK_EQ(type, SOCK_DGRAM);
    DCHECK(family == AF_INET || family == AF_INET6);
    net::SocketDescriptor result = net::kInvalidSocket;
    std::swap(result, socket_);
    return result;
  }

 private:
  // Descriptor waiting to be handed out, or |net::kInvalidSocket| once taken.
  net::SocketDescriptor socket_;
  DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory);
};

// Plain bundle describing one pre-created socket: the platform descriptor
// plus the address family and interface index it is meant for.
struct SocketInfo {
  net::SocketDescriptor socket;
  net::AddressFamily address_family;
  uint32 interface_index;

  SocketInfo(net::SocketDescriptor socket,
             net::AddressFamily address_family,
             uint32 interface_index)
      : socket(socket),
        address_family(address_family),
        interface_index(interface_index) {}
};

// Returns list of sockets preallocated before.
class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory {
 public:
  PreCreatedMDnsSocketFactory() {}
  virtual ~PreCreatedMDnsSocketFactory() {
    // Not empty if process exits too fast, before starting mDns code. If
    // happened, destructors may crash accessing destroyed global objects.
    sockets_.weak_clear();
  }

  // net::MDnsSocketFactory implementation:
  virtual void CreateSockets(
      ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
    sockets->swap(sockets_);
    Reset();
  }

  void AddSocket(const SocketInfo& socket_info) {
    // Takes ownership of socket_info.socket;
    ScopedSocketFactory platform_factory(socket_info.socket);
    scoped_ptr<net::DatagramServerSocket> socket(
        net::CreateAndBindMDnsSocket(socket_info.address_family,
                                     socket_info.interface_index));
    if (socket) {
      socket->DetachFromThread();
      sockets_.push_back(socket.release());
    }
  }

  void Reset() {
    sockets_.clear();
  }

 private:
  ScopedVector<net::DatagramServerSocket> sockets_;

  DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory);
};

// Lazily constructed, file-local socket factory. It is filled by
// StaticInitializeSocketFactory() (Windows) or OnSetSockets() (POSIX) and
// passed to the mDNS client in InitializeMdns().
base::LazyInstance<PreCreatedMDnsSocketFactory>
    g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER;

#if defined(OS_WIN)

// Windows flavour: releases |socket| through Winsock's closesocket(). The
// return value is ignored.
void ClosePlatformSocket(net::SocketDescriptor socket) {
  ::closesocket(socket);
}

void StaticInitializeSocketFactory() {
  net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind());
  for (size_t i = 0; i < interfaces.size(); ++i) {
    DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 ||
           interfaces[i].second == net::ADDRESS_FAMILY_IPV6);
    net::SocketDescriptor descriptor =
        net::CreatePlatformSocket(
            net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM,
                                      IPPROTO_UDP);
    g_local_discovery_socket_factory.Get().AddSocket(
        SocketInfo(descriptor, interfaces[i].second, interfaces[i].first));
  }
}

#else  // OS_WIN

// Non-Windows flavour: releases |socket| through close(). The return value is
// ignored.
void ClosePlatformSocket(net::SocketDescriptor socket) {
  ::close(socket);
}

// Non-Windows flavour: intentionally does nothing. On POSIX the sockets are
// registered by OnSetSockets() when LocalDiscoveryMsg_SetSockets arrives.
void StaticInitializeSocketFactory() {
}

#endif  // OS_WIN

// Hands |msg| to the content::UtilityThread singleton for sending. Scheduled
// by ServiceDiscoveryMessageHandler::Send() on the utility task runner.
void SendHostMessageOnUtilityThread(IPC::Message* msg) {
  content::UtilityThread* utility_thread = content::UtilityThread::Get();
  utility_thread->Send(msg);
}

// Maps a ServiceWatcher update type to its enumerator name for logging.
// Values outside the known enumerators yield "Unknown Update".
std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) {
  const char* label = "Unknown Update";
  switch (update) {
    case ServiceWatcher::UPDATE_ADDED:
      label = "UPDATE_ADDED";
      break;
    case ServiceWatcher::UPDATE_CHANGED:
      label = "UPDATE_CHANGED";
      break;
    case ServiceWatcher::UPDATE_REMOVED:
      label = "UPDATE_REMOVED";
      break;
    case ServiceWatcher::UPDATE_INVALIDATED:
      label = "UPDATE_INVALIDATED";
      break;
  }
  return label;
}

// Maps a ServiceResolver request status to its enumerator name for logging.
// Values outside the known enumerators yield "Unknown Status".
std::string ResolverStatusToString(ServiceResolver::RequestStatus status) {
  switch (status) {
    case ServiceResolver::STATUS_SUCCESS:
      // The label now matches the enumerator's spelling (was "STATUS_SUCESS").
      return "STATUS_SUCCESS";
    case ServiceResolver::STATUS_REQUEST_TIMEOUT:
      return "STATUS_REQUEST_TIMEOUT";
    case ServiceResolver::STATUS_KNOWN_NONEXISTENT:
      return "STATUS_KNOWN_NONEXISTENT";
  }
  return "Unknown Status";
}

}  // namespace

// The body is empty: the discovery thread and the mDNS client are created
// later, by InitializeThread() and InitializeMdns().
ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
}

// Asserts (DCHECK) that the discovery thread has already been destroyed;
// ShutdownLocalDiscovery() is what resets |discovery_thread_|.
ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
  DCHECK(!discovery_thread_);
}

// Runs StaticInitializeSocketFactory(), which pre-creates the mDNS sockets on
// Windows and does nothing on other platforms.
// NOTE(review): judging by the name this is meant to run before the sandbox is
// engaged - confirm against the caller.
void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
  StaticInitializeSocketFactory();
}

void ServiceDiscoveryMessageHandler::InitializeMdns() {
  if (service_discovery_client_ || mdns_client_)
    return;

  mdns_client_ = net::MDnsClient::CreateDefault();
  bool result =
      mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer());
  // Close unused sockets.
  g_local_discovery_socket_factory.Get().Reset();
  if (!result) {
    VLOG(1) << "Failed to start MDnsClient";
    Send(new LocalDiscoveryHostMsg_Error());
    return;
  }

  service_discovery_client_.reset(
      new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get()));
}

bool ServiceDiscoveryMessageHandler::InitializeThread() {
  if (discovery_task_runner_)
    return true;
  if (discovery_thread_)
    return false;
  utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy();
  discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread"));
  base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0);
  if (discovery_thread_->StartWithOptions(thread_options)) {
    discovery_task_runner_ = discovery_thread_->message_loop_proxy();
    discovery_task_runner_->PostTask(FROM_HERE,
        base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns,
                    base::Unretained(this)));
  }
  return discovery_task_runner_ != NULL;
}

// Dispatches an incoming LocalDiscoveryMsg_* IPC message to the matching
// handler method. Returns true if |message| matched one of the entries below
// and false otherwise.
bool ServiceDiscoveryMessageHandler::OnMessageReceived(
    const IPC::Message& message) {
  bool handled = true;
  IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message)
#if defined(OS_POSIX)
    // Sockets are only delivered over IPC on POSIX builds.
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets)
#endif  // OS_POSIX
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices,
                        OnSetActivelyRefreshServices)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain,
                        OnResolveLocalDomain)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver,
                        OnDestroyLocalDomainResolver)
    // Shutdown is handled directly rather than through an On* wrapper.
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery,
                        ShutdownLocalDiscovery)
    // Anything else is reported back to the caller as not handled.
    IPC_MESSAGE_UNHANDLED(handled = false)
  IPC_END_MESSAGE_MAP()
  return handled;
}

// Schedules |task| on the discovery thread, starting that thread first if
// necessary. The task is dropped when the thread is not available.
void ServiceDiscoveryMessageHandler::PostTask(
    const tracked_objects::Location& from_here,
    const base::Closure& task) {
  if (InitializeThread())
    discovery_task_runner_->PostTask(from_here, task);
}

#if defined(OS_POSIX)
void ServiceDiscoveryMessageHandler::OnSetSockets(
    const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
  for (size_t i = 0; i < sockets.size(); ++i) {
    g_local_discovery_socket_factory.Get().AddSocket(
        SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family,
                   sockets[i].interface_index));
  }
}
#endif  // OS_POSIX

// Handles LocalDiscoveryMsg_StartWatcher by scheduling StartWatcher() on the
// discovery thread.
void ServiceDiscoveryMessageHandler::OnStartWatcher(
    uint64 id,
    const std::string& service_type) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher,
                 base::Unretained(this), id, service_type);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_DiscoverServices by scheduling DiscoverServices()
// on the discovery thread.
void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id,
                                                        bool force_update) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices,
                 base::Unretained(this), id, force_update);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_SetActivelyRefreshServices by scheduling
// SetActivelyRefreshServices() on the discovery thread.
void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices(
    uint64 id, bool actively_refresh_services) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::SetActivelyRefreshServices,
                 base::Unretained(this), id, actively_refresh_services);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_DestroyWatcher by scheduling DestroyWatcher() on
// the discovery thread.
void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher,
                 base::Unretained(this), id);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_ResolveService by scheduling ResolveService() on
// the discovery thread.
void ServiceDiscoveryMessageHandler::OnResolveService(
    uint64 id,
    const std::string& service_name) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::ResolveService,
                 base::Unretained(this), id, service_name);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_DestroyResolver by scheduling DestroyResolver() on
// the discovery thread.
void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver,
                 base::Unretained(this), id);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_ResolveLocalDomain by scheduling
// ResolveLocalDomain() on the discovery thread.
void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
    uint64 id, const std::string& domain,
    net::AddressFamily address_family) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain,
                 base::Unretained(this), id, domain, address_family);
  PostTask(FROM_HERE, task);
}

// Handles LocalDiscoveryMsg_DestroyLocalDomainResolver by scheduling
// DestroyLocalDomainResolver() on the discovery thread.
void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) {
  const base::Closure task =
      base::Bind(&ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver,
                 base::Unretained(this), id);
  PostTask(FROM_HERE, task);
}

// Creates and starts a service watcher for |service_type| and stores it under
// |id|. Updates are reported through OnServiceUpdated() with the same |id|.
// Does nothing if the service discovery client is not available.
void ServiceDiscoveryMessageHandler::StartWatcher(
    uint64 id,
    const std::string& service_type) {
  VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type;
  if (!service_discovery_client_)
    return;

  // Each id is expected to be used for at most one watcher.
  DCHECK(!ContainsKey(service_watchers_, id));

  scoped_ptr<ServiceWatcher> new_watcher(
      service_discovery_client_->CreateServiceWatcher(
          service_type,
          base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated,
                     base::Unretained(this), id)));
  new_watcher->Start();

  // The map entry takes over ownership of the started watcher.
  service_watchers_[id].reset(new_watcher.release());
}

// Asks the watcher registered under |id| to discover new services, passing
// |force_update| through. Does nothing if the service discovery client is not
// available. The watcher is expected to exist (DCHECKed).
void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id,
                                                      bool force_update) {
  VLOG(1) << "DiscoverServices, id=" << id;
  if (service_discovery_client_) {
    DCHECK(ContainsKey(service_watchers_, id));
    service_watchers_[id]->DiscoverNewServices(force_update);
  }
}

// Forwards |actively_refresh_services| to the watcher registered under |id|.
// Does nothing if the service discovery client is not available. The watcher
// is expected to exist (DCHECKed).
void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices(
    uint64 id,
    bool actively_refresh_services) {
  VLOG(1) << "ActivelyRefreshServices, id=" << id;
  if (service_discovery_client_) {
    DCHECK(ContainsKey(service_watchers_, id));
    service_watchers_[id]->SetActivelyRefreshServices(
        actively_refresh_services);
  }
}

// Removes the watcher registered under |id|, if any. Does nothing if the
// service discovery client is not available.
void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) {
  // Log label corrected from the misspelled "DestoryWatcher" so that it
  // matches the function name when searching logs.
  VLOG(1) << "DestroyWatcher, id=" << id;
  if (!service_discovery_client_)
    return;
  service_watchers_.erase(id);
}

// Creates a resolver for |service_name|, starts it and stores it under |id|.
// The result is reported through OnServiceResolved() with the same |id|. Does
// nothing if the service discovery client is not available.
void ServiceDiscoveryMessageHandler::ResolveService(
    uint64 id,
    const std::string& service_name) {
  VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name;
  if (!service_discovery_client_)
    return;

  // Each id is expected to be used for at most one resolver.
  DCHECK(!ContainsKey(service_resolvers_, id));

  scoped_ptr<ServiceResolver> new_resolver(
      service_discovery_client_->CreateServiceResolver(
          service_name,
          base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved,
                     base::Unretained(this), id)));
  new_resolver->StartResolving();

  // The map entry takes over ownership of the started resolver.
  service_resolvers_[id].reset(new_resolver.release());
}

// Removes the service resolver registered under |id|, if any. Does nothing if
// the service discovery client is not available.
void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) {
  VLOG(1) << "DestroyResolver, id=" << id;
  if (service_discovery_client_)
    service_resolvers_.erase(id);
}

// Creates a local domain resolver for |domain| and |address_family|, starts
// it and stores it under |id|. The result is reported through
// OnLocalDomainResolved() with the same |id|. Does nothing if the service
// discovery client is not available.
void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
    uint64 id,
    const std::string& domain,
    net::AddressFamily address_family) {
  VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain;
  if (!service_discovery_client_)
    return;

  // Each id is expected to be used for at most one resolver.
  DCHECK(!ContainsKey(local_domain_resolvers_, id));

  scoped_ptr<LocalDomainResolver> new_resolver(
      service_discovery_client_->CreateLocalDomainResolver(
          domain, address_family,
          base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved,
                     base::Unretained(this), id)));
  new_resolver->Start();

  // The map entry takes over ownership of the started resolver.
  local_domain_resolvers_[id].reset(new_resolver.release());
}

// Removes the local domain resolver registered under |id|, if any. Does
// nothing if the service discovery client is not available.
void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) {
  VLOG(1) << "DestroyLocalDomainResolver, id=" << id;
  if (service_discovery_client_)
    local_domain_resolvers_.erase(id);
}

void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
  if (!discovery_task_runner_)
    return;

  discovery_task_runner_->PostTask(
      FROM_HERE,
      base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread,
                 base::Unretained(this)));

  // This will wait for message loop to drain, so ShutdownOnIOThread will
  // definitely be called.
  discovery_thread_.reset();
}

// Releases all discovery state. Posted to the discovery thread by
// ShutdownLocalDiscovery().
void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
  VLOG(1) << "ShutdownLocalDiscovery";
  // Watchers and resolvers are created by |service_discovery_client_|, which
  // in turn is constructed from |mdns_client_|; they are released in that
  // order, dependents first.
  service_watchers_.clear();
  service_resolvers_.clear();
  local_domain_resolvers_.clear();
  service_discovery_client_.reset();
  mdns_client_.reset();
}

// Callback bound in StartWatcher(): relays a watcher update for |id| to the
// host as a LocalDiscoveryHostMsg_WatcherCallback.
void ServiceDiscoveryMessageHandler::OnServiceUpdated(
    uint64 id,
    ServiceWatcher::UpdateType update,
    const std::string& name) {
  VLOG(1) << "OnServiceUpdated, id=" << id
          << ", status=" << WatcherUpdateToString(update)
          << ", name=" << name;
  DCHECK(service_discovery_client_);

  IPC::Message* reply =
      new LocalDiscoveryHostMsg_WatcherCallback(id, update, name);
  Send(reply);
}

// Callback bound in ResolveService(): relays the resolution result for |id|
// to the host as a LocalDiscoveryHostMsg_ResolverCallback.
void ServiceDiscoveryMessageHandler::OnServiceResolved(
    uint64 id,
    ServiceResolver::RequestStatus status,
    const ServiceDescription& description) {
  VLOG(1) << "OnServiceResolved, id=" << id
          << ", status=" << ResolverStatusToString(status)
          << ", name=" << description.service_name;
  DCHECK(service_discovery_client_);

  IPC::Message* reply =
      new LocalDiscoveryHostMsg_ResolverCallback(id, status, description);
  Send(reply);
}

// Callback bound in ResolveLocalDomain(): relays the resolved addresses for
// |id| to the host as a LocalDiscoveryHostMsg_LocalDomainResolverCallback.
// Empty addresses are logged as empty strings.
void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
    uint64 id,
    bool success,
    const net::IPAddressNumber& address_ipv4,
    const net::IPAddressNumber& address_ipv6) {
  VLOG(1) << "OnLocalDomainResolved, id=" << id
          << ", IPv4="
          << (address_ipv4.empty() ? "" : net::IPAddressToString(address_ipv4))
          << ", IPv6="
          << (address_ipv6.empty() ? "" : net::IPAddressToString(address_ipv6));
  DCHECK(service_discovery_client_);

  IPC::Message* reply = new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
      id, success, address_ipv4, address_ipv6);
  Send(reply);
}

// Schedules |msg| for delivery via SendHostMessageOnUtilityThread() on
// |utility_task_runner_|, the loop recorded by InitializeThread().
void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) {
  const base::Closure send_task =
      base::Bind(&SendHostMessageOnUtilityThread, msg);
  utility_task_runner_->PostTask(FROM_HERE, send_task);
}

}  // namespace local_discovery

/* [<][>][^][v][top][bottom][index][help] */