// Copyright 2016 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. //#include "DispatcherBase.h" //#include "Parser.h" {% for namespace in config.protocol.namespace %} namespace {{namespace}} { {% endfor %} // static DispatchResponse DispatchResponse::OK() { DispatchResponse result; result.m_status = kSuccess; result.m_errorCode = kParseError; return result; } // static DispatchResponse DispatchResponse::Error(const String& error) { DispatchResponse result; result.m_status = kError; result.m_errorCode = kServerError; result.m_errorMessage = error; return result; } // static DispatchResponse DispatchResponse::InternalError() { DispatchResponse result; result.m_status = kError; result.m_errorCode = kInternalError; result.m_errorMessage = "Internal error"; return result; } // static DispatchResponse DispatchResponse::InvalidParams(const String& error) { DispatchResponse result; result.m_status = kError; result.m_errorCode = kInvalidParams; result.m_errorMessage = error; return result; } // static DispatchResponse DispatchResponse::FallThrough() { DispatchResponse result; result.m_status = kFallThrough; result.m_errorCode = kParseError; return result; } // static const char DispatcherBase::kInvalidParamsString[] = "Invalid parameters"; DispatcherBase::WeakPtr::WeakPtr(DispatcherBase* dispatcher) : m_dispatcher(dispatcher) { } DispatcherBase::WeakPtr::~WeakPtr() { if (m_dispatcher) m_dispatcher->m_weakPtrs.erase(this); } DispatcherBase::Callback::Callback(std::unique_ptr<DispatcherBase::WeakPtr> backendImpl, int callId, int callbackId) : m_backendImpl(std::move(backendImpl)) , m_callId(callId) , m_callbackId(callbackId) { } DispatcherBase::Callback::~Callback() = default; void DispatcherBase::Callback::dispose() { m_backendImpl = nullptr; } void DispatcherBase::Callback::sendIfActive(std::unique_ptr<protocol::DictionaryValue> partialMessage, const DispatchResponse& response) { if (!m_backendImpl || !m_backendImpl->get()) return; m_backendImpl->get()->sendResponse(m_callId, response, std::move(partialMessage)); m_backendImpl = nullptr; } void DispatcherBase::Callback::fallThroughIfActive() { if (!m_backendImpl || !m_backendImpl->get()) return; m_backendImpl->get()->markFallThrough(m_callbackId); m_backendImpl = nullptr; } DispatcherBase::DispatcherBase(FrontendChannel* frontendChannel) : m_frontendChannel(frontendChannel) , m_lastCallbackId(0) , m_lastCallbackFallThrough(false) { } DispatcherBase::~DispatcherBase() { clearFrontend(); } int DispatcherBase::nextCallbackId() { m_lastCallbackFallThrough = false; return ++m_lastCallbackId; } void DispatcherBase::markFallThrough(int callbackId) { DCHECK(callbackId == m_lastCallbackId); m_lastCallbackFallThrough = true; } void DispatcherBase::sendResponse(int callId, const DispatchResponse& response, std::unique_ptr<protocol::DictionaryValue> result) { if (!m_frontendChannel) return; if (response.status() == DispatchResponse::kError) { reportProtocolError(callId, response.errorCode(), response.errorMessage(), nullptr); return; } m_frontendChannel->sendProtocolResponse(callId, InternalResponse::createResponse(callId, std::move(result))); } void DispatcherBase::sendResponse(int callId, const DispatchResponse& response) { sendResponse(callId, response, DictionaryValue::create()); } namespace { class ProtocolError : public Serializable { public: static std::unique_ptr<ProtocolError> createErrorResponse(int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { std::unique_ptr<ProtocolError> protocolError(new ProtocolError(code, errorMessage)); protocolError->m_callId = callId; protocolError->m_hasCallId = true; if (errors && errors->hasErrors()) protocolError->m_data = errors->errors(); return protocolError; } static std::unique_ptr<ProtocolError> createErrorNotification(DispatchResponse::ErrorCode code, const String& errorMessage) { return std::unique_ptr<ProtocolError>(new ProtocolError(code, errorMessage)); } String serialize() override { std::unique_ptr<protocol::DictionaryValue> error = DictionaryValue::create(); error->setInteger("code", m_code); error->setString("message", m_errorMessage); if (m_data.length()) error->setString("data", m_data); std::unique_ptr<protocol::DictionaryValue> message = DictionaryValue::create(); message->setObject("error", std::move(error)); if (m_hasCallId) message->setInteger("id", m_callId); return message->serialize(); } ~ProtocolError() override {} private: ProtocolError(DispatchResponse::ErrorCode code, const String& errorMessage) : m_code(code) , m_errorMessage(errorMessage) { } DispatchResponse::ErrorCode m_code; String m_errorMessage; String m_data; int m_callId = 0; bool m_hasCallId = false; }; } // namespace static void reportProtocolErrorTo(FrontendChannel* frontendChannel, int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { if (frontendChannel) frontendChannel->sendProtocolResponse(callId, ProtocolError::createErrorResponse(callId, code, errorMessage, errors)); } static void reportProtocolErrorTo(FrontendChannel* frontendChannel, DispatchResponse::ErrorCode code, const String& errorMessage) { if (frontendChannel) frontendChannel->sendProtocolNotification(ProtocolError::createErrorNotification(code, errorMessage)); } void DispatcherBase::reportProtocolError(int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { reportProtocolErrorTo(m_frontendChannel, callId, code, errorMessage, errors); } void DispatcherBase::clearFrontend() { m_frontendChannel = nullptr; for (auto& weak : m_weakPtrs) weak->dispose(); m_weakPtrs.clear(); } std::unique_ptr<DispatcherBase::WeakPtr> DispatcherBase::weakPtr() { std::unique_ptr<DispatcherBase::WeakPtr> weak(new DispatcherBase::WeakPtr(this)); m_weakPtrs.insert(weak.get()); return weak; } UberDispatcher::UberDispatcher(FrontendChannel* frontendChannel) : m_frontendChannel(frontendChannel) , m_fallThroughForNotFound(false) { } void UberDispatcher::setFallThroughForNotFound(bool fallThroughForNotFound) { m_fallThroughForNotFound = fallThroughForNotFound; } void UberDispatcher::registerBackend(const String& name, std::unique_ptr<protocol::DispatcherBase> dispatcher) { m_dispatchers[name] = std::move(dispatcher); } void UberDispatcher::setupRedirects(const HashMap<String, String>& redirects) { for (const auto& pair : redirects) m_redirects[pair.first] = pair.second; } DispatchResponse::Status UberDispatcher::dispatch(std::unique_ptr<Value> parsedMessage, int* outCallId, String* outMethod) { if (!parsedMessage) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kParseError, "Message must be a valid JSON"); return DispatchResponse::kError; } std::unique_ptr<protocol::DictionaryValue> messageObject = DictionaryValue::cast(std::move(parsedMessage)); if (!messageObject) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must be an object"); return DispatchResponse::kError; } int callId = 0; protocol::Value* callIdValue = messageObject->get("id"); bool success = callIdValue && callIdValue->asInteger(&callId); if (outCallId) *outCallId = callId; if (!success) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must have integer 'id' property"); return DispatchResponse::kError; } protocol::Value* methodValue = messageObject->get("method"); String method; success = methodValue && methodValue->asString(&method); if (outMethod) *outMethod = method; if (!success) { reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kInvalidRequest, "Message must have string 'method' property", nullptr); return DispatchResponse::kError; } HashMap<String, String>::iterator redirectIt = m_redirects.find(method); if (redirectIt != m_redirects.end()) method = redirectIt->second; size_t dotIndex = StringUtil::find(method, "."); if (dotIndex == StringUtil::kNotFound) { if (m_fallThroughForNotFound) return DispatchResponse::kFallThrough; reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kMethodNotFound, "'" + method + "' wasn't found", nullptr); return DispatchResponse::kError; } String domain = StringUtil::substring(method, 0, dotIndex); auto it = m_dispatchers.find(domain); if (it == m_dispatchers.end()) { if (m_fallThroughForNotFound) return DispatchResponse::kFallThrough; reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kMethodNotFound, "'" + method + "' wasn't found", nullptr); return DispatchResponse::kError; } return it->second->dispatch(callId, method, std::move(messageObject)); } bool UberDispatcher::getCommandName(const String& message, String* method, std::unique_ptr<protocol::DictionaryValue>* parsedMessage) { std::unique_ptr<protocol::Value> value = StringUtil::parseJSON(message); if (!value) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kParseError, "Message must be a valid JSON"); return false; } protocol::DictionaryValue* object = DictionaryValue::cast(value.get()); if (!object) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must be an object"); return false; } if (!object->getString("method", method)) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must have string 'method' property"); return false; } parsedMessage->reset(DictionaryValue::cast(value.release())); return true; } UberDispatcher::~UberDispatcher() = default; // static std::unique_ptr<InternalResponse> InternalResponse::createResponse(int callId, std::unique_ptr<Serializable> params) { return std::unique_ptr<InternalResponse>(new InternalResponse(callId, String(), std::move(params))); } // static std::unique_ptr<InternalResponse> InternalResponse::createNotification(const String& notification, std::unique_ptr<Serializable> params) { return std::unique_ptr<InternalResponse>(new InternalResponse(0, notification, std::move(params))); } String InternalResponse::serialize() { std::unique_ptr<DictionaryValue> result = DictionaryValue::create(); std::unique_ptr<Serializable> params(m_params ? std::move(m_params) : DictionaryValue::create()); if (m_notification.length()) { result->setString("method", m_notification); result->setValue("params", SerializedValue::create(params->serialize())); } else { result->setInteger("id", m_callId); result->setValue("result", SerializedValue::create(params->serialize())); } return result->serialize(); } InternalResponse::InternalResponse(int callId, const String& notification, std::unique_ptr<Serializable> params) : m_callId(callId) , m_notification(notification) , m_params(params ? std::move(params) : nullptr) { } {% for namespace in config.protocol.namespace %} } // namespace {{namespace}} {% endfor %}