root/net/websockets/websocket_handshake_handler.cc

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes following definitions.
  1. ParseHandshakeHeader
  2. FetchHeaders
  3. GetHeaderName
  4. FilterHeaders
  5. CheckVersionInRequest
  6. AppendHeader
  7. raw_length_
  8. ParseRequest
  9. original_length
  10. AppendHeaderIfMissing
  11. RemoveHeaders
  12. GetRequestInfo
  13. GetRequestHeaderBlock
  14. GetRawRequest
  15. raw_length
  16. ParseRawResponse
  17. HasResponse
  18. ComputeSecWebSocketAccept
  19. ParseResponseInfo
  20. ParseResponseHeaderBlock
  21. GetHeaders
  22. RemoveHeaders
  23. GetRawResponse
  24. GetResponse

// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/websockets/websocket_handshake_handler.h"

#include <limits>

#include "base/base64.h"
#include "base/sha1.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_tokenizer.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "url/gurl.h"

namespace net {
namespace {

const int kVersionHeaderValueForRFC6455 = 13;

// Splits |handshake_message| into Status-Line or Request-Line (including CRLF)
// and headers (excluding 2nd CRLF of double CRLFs at the end of a handshake
// response).
void ParseHandshakeHeader(
    const char* handshake_message, int len,
    std::string* request_line,
    std::string* headers) {
  size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n");
  if (i == base::StringPiece::npos) {
    *request_line = std::string(handshake_message, len);
    *headers = "";
    return;
  }
  // |request_line| includes \r\n.
  *request_line = std::string(handshake_message, i + 2);

  int header_len = len - (i + 2) - 2;
  if (header_len > 0) {
    // |handshake_message| includes trailing \r\n\r\n.
    // |headers| doesn't include 2nd \r\n.
    *headers = std::string(handshake_message + i + 2, header_len);
  } else {
    *headers = "";
  }
}

void FetchHeaders(const std::string& headers,
                  const char* const headers_to_get[],
                  size_t headers_to_get_len,
                  std::vector<std::string>* values) {
  net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
  while (iter.GetNext()) {
    for (size_t i = 0; i < headers_to_get_len; i++) {
      if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
                               headers_to_get[i])) {
        values->push_back(iter.values());
      }
    }
  }
}

bool GetHeaderName(std::string::const_iterator line_begin,
                   std::string::const_iterator line_end,
                   std::string::const_iterator* name_begin,
                   std::string::const_iterator* name_end) {
  std::string::const_iterator colon = std::find(line_begin, line_end, ':');
  if (colon == line_end) {
    return false;
  }
  *name_begin = line_begin;
  *name_end = colon;
  if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
    return false;
  net::HttpUtil::TrimLWS(name_begin, name_end);
  return true;
}

// Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
// is, lines that are not formatted as "<name>: <value>\r\n".
std::string FilterHeaders(
    const std::string& headers,
    const char* const headers_to_remove[],
    size_t headers_to_remove_len) {
  std::string filtered_headers;

  base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
  while (lines.GetNext()) {
    std::string::const_iterator line_begin = lines.token_begin();
    std::string::const_iterator line_end = lines.token_end();
    std::string::const_iterator name_begin;
    std::string::const_iterator name_end;
    bool should_remove = false;
    if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
      for (size_t i = 0; i < headers_to_remove_len; ++i) {
        if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
          should_remove = true;
          break;
        }
      }
    }
    if (!should_remove) {
      filtered_headers.append(line_begin, line_end);
      filtered_headers.append("\r\n");
    }
  }
  return filtered_headers;
}

bool CheckVersionInRequest(const std::string& request_headers) {
  std::vector<std::string> values;
  const char* const headers_to_get[1] = {
    websockets::kSecWebSocketVersionLowercase};
  FetchHeaders(request_headers, headers_to_get, 1, &values);
  DCHECK_LE(values.size(), 1U);
  if (values.empty())
    return false;

  int version;
  bool conversion_success = base::StringToInt(values[0], &version);
  if (!conversion_success)
    return false;

  return version == kVersionHeaderValueForRFC6455;
}

// Append a header to a string. Equivalent to
//   response_message += header + ": " + value + "\r\n"
// but avoids unnecessary allocations and copies.
void AppendHeader(const base::StringPiece& header,
                  const base::StringPiece& value,
                  std::string* response_message) {
  static const char kColonSpace[] = ": ";
  const size_t kColonSpaceSize = sizeof(kColonSpace) - 1;
  static const char kCrNl[] = "\r\n";
  const size_t kCrNlSize = sizeof(kCrNl) - 1;

  size_t extra_size =
      header.size() + kColonSpaceSize + value.size() + kCrNlSize;
  response_message->reserve(response_message->size() + extra_size);
  response_message->append(header.begin(), header.end());
  response_message->append(kColonSpace, kColonSpace + kColonSpaceSize);
  response_message->append(value.begin(), value.end());
  response_message->append(kCrNl, kCrNl + kCrNlSize);
}

}  // namespace

WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
    : original_length_(0),
      raw_length_(0) {}

bool WebSocketHandshakeRequestHandler::ParseRequest(
    const char* data, int length) {
  DCHECK_GT(length, 0);
  std::string input(data, length);
  int input_header_length =
      HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
  if (input_header_length <= 0)
    return false;

  ParseHandshakeHeader(input.data(),
                       input_header_length,
                       &request_line_,
                       &headers_);

  if (!CheckVersionInRequest(headers_)) {
    NOTREACHED();
    return false;
  }

  original_length_ = input_header_length;
  return true;
}

size_t WebSocketHandshakeRequestHandler::original_length() const {
  return original_length_;
}

void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
    const std::string& name, const std::string& value) {
  DCHECK(!headers_.empty());
  HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
}

void WebSocketHandshakeRequestHandler::RemoveHeaders(
    const char* const headers_to_remove[],
    size_t headers_to_remove_len) {
  DCHECK(!headers_.empty());
  headers_ = FilterHeaders(
      headers_, headers_to_remove, headers_to_remove_len);
}

HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
    const GURL& url, std::string* challenge) {
  HttpRequestInfo request_info;
  request_info.url = url;
  size_t method_end = base::StringPiece(request_line_).find_first_of(" ");
  if (method_end != base::StringPiece::npos)
    request_info.method = std::string(request_line_.data(), method_end);

  request_info.extra_headers.Clear();
  request_info.extra_headers.AddHeadersFromString(headers_);

  request_info.extra_headers.RemoveHeader(websockets::kUpgrade);
  request_info.extra_headers.RemoveHeader(HttpRequestHeaders::kConnection);

  std::string key;
  bool header_present = request_info.extra_headers.GetHeader(
      websockets::kSecWebSocketKey, &key);
  DCHECK(header_present);
  request_info.extra_headers.RemoveHeader(websockets::kSecWebSocketKey);
  *challenge = key;
  return request_info;
}

bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
    const GURL& url,
    SpdyHeaderBlock* headers,
    std::string* challenge,
    int spdy_protocol_version) {
  // Construct opening handshake request headers as a SPDY header block.
  // For details, see WebSocket Layering over SPDY/3 Draft 8.
  if (spdy_protocol_version <= 2) {
    (*headers)["path"] = url.path();
    (*headers)["version"] = "WebSocket/13";
    (*headers)["scheme"] = url.scheme();
  } else {
    (*headers)[":path"] = url.path();
    (*headers)[":version"] = "WebSocket/13";
    (*headers)[":scheme"] = url.scheme();
  }

  HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
  while (iter.GetNext()) {
    if (LowerCaseEqualsASCII(iter.name_begin(),
                             iter.name_end(),
                             websockets::kUpgradeLowercase) ||
        LowerCaseEqualsASCII(
            iter.name_begin(), iter.name_end(), "connection") ||
        LowerCaseEqualsASCII(iter.name_begin(),
                             iter.name_end(),
                             websockets::kSecWebSocketVersionLowercase)) {
      // These headers must be ignored.
      continue;
    } else if (LowerCaseEqualsASCII(iter.name_begin(),
                                    iter.name_end(),
                                    websockets::kSecWebSocketKeyLowercase)) {
      *challenge = iter.values();
      // Sec-WebSocket-Key is not sent to the server.
      continue;
    } else if (LowerCaseEqualsASCII(
                   iter.name_begin(), iter.name_end(), "host") ||
               LowerCaseEqualsASCII(
                   iter.name_begin(), iter.name_end(), "origin") ||
               LowerCaseEqualsASCII(
                   iter.name_begin(),
                   iter.name_end(),
                   websockets::kSecWebSocketProtocolLowercase) ||
               LowerCaseEqualsASCII(
                   iter.name_begin(),
                   iter.name_end(),
                   websockets::kSecWebSocketExtensionsLowercase)) {
      // TODO(toyoshim): Some WebSocket extensions may not be compatible with
      // SPDY. We should omit them from a Sec-WebSocket-Extension header.
      std::string name;
      if (spdy_protocol_version <= 2)
        name = StringToLowerASCII(iter.name());
      else
        name = ":" + StringToLowerASCII(iter.name());
      (*headers)[name] = iter.values();
      continue;
    }
    // Others should be sent out to |headers|.
    std::string name = StringToLowerASCII(iter.name());
    SpdyHeaderBlock::iterator found = headers->find(name);
    if (found == headers->end()) {
      (*headers)[name] = iter.values();
    } else {
      // For now, websocket doesn't use multiple headers, but follows to http.
      found->second.append(1, '\0');  // +=() doesn't append 0's
      found->second.append(iter.values());
    }
  }

  return true;
}

std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
  DCHECK(!request_line_.empty());
  DCHECK(!headers_.empty());

  std::string raw_request = request_line_ + headers_ + "\r\n";
  raw_length_ = raw_request.size();
  return raw_request;
}

size_t WebSocketHandshakeRequestHandler::raw_length() const {
  DCHECK_GT(raw_length_, 0);
  return raw_length_;
}

WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
    : original_header_length_(0) {}

WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}

size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
    const char* data, int length) {
  DCHECK_GT(length, 0);
  if (HasResponse()) {
    DCHECK(!status_line_.empty());
    // headers_ might be empty for wrong response from server.

    return 0;
  }

  size_t old_original_length = original_.size();

  original_.append(data, length);
  // TODO(ukai): fail fast when response gives wrong status code.
  original_header_length_ = HttpUtil::LocateEndOfHeaders(
      original_.data(), original_.size(), 0);
  if (!HasResponse())
    return length;

  ParseHandshakeHeader(original_.data(),
                       original_header_length_,
                       &status_line_,
                       &headers_);
  int header_size = status_line_.size() + headers_.size();
  DCHECK_GE(original_header_length_, header_size);
  header_separator_ = std::string(original_.data() + header_size,
                                  original_header_length_ - header_size);
  return original_header_length_ - old_original_length;
}

bool WebSocketHandshakeResponseHandler::HasResponse() const {
  return original_header_length_ > 0 &&
      static_cast<size_t>(original_header_length_) <= original_.size();
}

void ComputeSecWebSocketAccept(const std::string& key,
                               std::string* accept) {
  DCHECK(accept);

  std::string hash =
      base::SHA1HashString(key + websockets::kWebSocketGuid);
  base::Base64Encode(hash, accept);
}

bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
    const HttpResponseInfo& response_info,
    const std::string& challenge) {
  if (!response_info.headers.get())
    return false;

  // TODO(ricea): Eliminate all the reallocations and string copies.
  std::string response_message;
  response_message = response_info.headers->GetStatusLine();
  response_message += "\r\n";

  AppendHeader(websockets::kUpgrade,
               websockets::kWebSocketLowercase,
               &response_message);

  AppendHeader(
      HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);

  std::string websocket_accept;
  ComputeSecWebSocketAccept(challenge, &websocket_accept);
  AppendHeader(
      websockets::kSecWebSocketAccept, websocket_accept, &response_message);

  void* iter = NULL;
  std::string name;
  std::string value;
  while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
    AppendHeader(name, value, &response_message);
  }
  response_message += "\r\n";

  return ParseRawResponse(response_message.data(),
                          response_message.size()) == response_message.size();
}

bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
    const SpdyHeaderBlock& headers,
    const std::string& challenge,
    int spdy_protocol_version) {
  SpdyHeaderBlock::const_iterator status;
  if (spdy_protocol_version <= 2)
    status = headers.find("status");
  else
    status = headers.find(":status");
  if (status == headers.end())
    return false;

  std::string hash =
      base::SHA1HashString(challenge + websockets::kWebSocketGuid);
  std::string websocket_accept;
  base::Base64Encode(hash, &websocket_accept);

  std::string response_message = base::StringPrintf(
      "%s %s\r\n", websockets::kHttpProtocolVersion, status->second.c_str());

  AppendHeader(
      websockets::kUpgrade, websockets::kWebSocketLowercase, &response_message);
  AppendHeader(
      HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
  AppendHeader(
      websockets::kSecWebSocketAccept, websocket_accept, &response_message);

  for (SpdyHeaderBlock::const_iterator iter = headers.begin();
       iter != headers.end();
       ++iter) {
    // For each value, if the server sends a NUL-separated list of values,
    // we separate that back out into individual headers for each value
    // in the list.
    if ((spdy_protocol_version <= 2 &&
         LowerCaseEqualsASCII(iter->first, "status")) ||
        (spdy_protocol_version >= 3 &&
         LowerCaseEqualsASCII(iter->first, ":status"))) {
      // The status value is already handled as the first line of
      // |response_message|. Just skip here.
      continue;
    }
    const std::string& value = iter->second;
    size_t start = 0;
    size_t end = 0;
    do {
      end = value.find('\0', start);
      std::string tval;
      if (end != std::string::npos)
        tval = value.substr(start, (end - start));
      else
        tval = value.substr(start);
      if (spdy_protocol_version >= 3 &&
          (LowerCaseEqualsASCII(iter->first,
                                websockets::kSecWebSocketProtocolSpdy3) ||
           LowerCaseEqualsASCII(iter->first,
                                websockets::kSecWebSocketExtensionsSpdy3)))
        AppendHeader(iter->first.substr(1), tval, &response_message);
      else
        AppendHeader(iter->first, tval, &response_message);
      start = end + 1;
    } while (end != std::string::npos);
  }
  response_message += "\r\n";

  return ParseRawResponse(response_message.data(),
                          response_message.size()) == response_message.size();
}

void WebSocketHandshakeResponseHandler::GetHeaders(
    const char* const headers_to_get[],
    size_t headers_to_get_len,
    std::vector<std::string>* values) {
  DCHECK(HasResponse());
  DCHECK(!status_line_.empty());
  // headers_ might be empty for wrong response from server.
  if (headers_.empty())
    return;

  FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
}

void WebSocketHandshakeResponseHandler::RemoveHeaders(
    const char* const headers_to_remove[],
    size_t headers_to_remove_len) {
  DCHECK(HasResponse());
  DCHECK(!status_line_.empty());
  // headers_ might be empty for wrong response from server.
  if (headers_.empty())
    return;

  headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
}

std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
  DCHECK(HasResponse());
  return original_.substr(0, original_header_length_);
}

std::string WebSocketHandshakeResponseHandler::GetResponse() {
  DCHECK(HasResponse());
  DCHECK(!status_line_.empty());
  // headers_ might be empty for wrong response from server.

  return status_line_ + headers_ + header_separator_;
}

}  // namespace net

/* [<][>][^][v][top][bottom][index][help] */