Bug 1286986: Add glue to enable process-local registration of COM proxies at runtime; r?jimm
MozReview-Commit-ID: 7VTCPQa90Vv
new file mode 100644
--- /dev/null
+++ b/ipc/mscom/Registration.cpp
@@ -0,0 +1,332 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+// COM registration data structures are built with C code, so we need to
+// simulate that in our C++ code by defining CINTERFACE before including
+// anything else that could possibly pull in Windows header files.
+#define CINTERFACE
+
+#include "mozilla/mscom/Registration.h"
+
+#include "mozilla/ArrayUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/ClearOnShutdown.h"
+#include "mozilla/Move.h"
+#include "mozilla/Mutex.h"
+#include "mozilla/Pair.h"
+#include "mozilla/StaticPtr.h"
+#include "nsTArray.h"
+
+#include <oaidl.h>
+#include <objidl.h>
+#include <rpcproxy.h>
+#include <shlwapi.h>
+
+/* This code MUST NOT use any non-inlined internal Mozilla APIs, as it will be
+ compiled into DLLs that COM may load into non-Mozilla processes! */
+
+namespace {
+
+// This function is defined in generated code for proxy DLLs but is not declared
+// in rpcproxy.h, so we need this typedef.
+typedef void (RPC_ENTRY *GetProxyDllInfoFnPtr)(const ProxyFileInfo*** aInfo,
+ const CLSID** aId);
+
+} // anonymous namespace
+
+namespace mozilla {
+namespace mscom {
+
+static bool
+BuildLibPath(RegistrationFlags aFlags, wchar_t* aBuffer, size_t aBufferLen,
+ const wchar_t* aLeafName)
+{
+ if (aFlags == RegistrationFlags::eUseBinDirectory) {
+ HMODULE thisModule = nullptr;
+ if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+ reinterpret_cast<LPCTSTR>(&RegisterProxy),
+ &thisModule)) {
+ return false;
+ }
+ DWORD fileNameResult = GetModuleFileName(thisModule, aBuffer, aBufferLen);
+ if (!fileNameResult || (fileNameResult == aBufferLen &&
+ ::GetLastError() == ERROR_INSUFFICIENT_BUFFER)) {
+ return false;
+ }
+ if (!PathRemoveFileSpec(aBuffer)) {
+ return false;
+ }
+ } else if (aFlags == RegistrationFlags::eUseSystemDirectory) {
+ UINT result = GetSystemDirectoryW(aBuffer, static_cast<UINT>(aBufferLen));
+ if (!result || result > aBufferLen) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+
+ if (!PathAppend(aBuffer, aLeafName)) {
+ return false;
+ }
+ return true;
+}
+
+UniquePtr<RegisteredProxy>
+RegisterProxy(const wchar_t* aLeafName, RegistrationFlags aFlags)
+{
+ wchar_t modulePathBuf[MAX_PATH + 1] = {0};
+ if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf),
+ aLeafName)) {
+ return nullptr;
+ }
+
+ HMODULE proxyDll = LoadLibrary(modulePathBuf);
+ if (!proxyDll) {
+ return nullptr;
+ }
+
+ auto DllGetClassObjectFn = reinterpret_cast<LPFNGETCLASSOBJECT>(
+ GetProcAddress(proxyDll, "DllGetClassObject"));
+ if (!DllGetClassObjectFn) {
+ FreeLibrary(proxyDll);
+ return nullptr;
+ }
+
+ auto GetProxyDllInfoFn = reinterpret_cast<GetProxyDllInfoFnPtr>(
+ GetProcAddress(proxyDll, "GetProxyDllInfo"));
+ if (!GetProxyDllInfoFn) {
+ FreeLibrary(proxyDll);
+ return nullptr;
+ }
+
+ const ProxyFileInfo** proxyInfo = nullptr;
+ const CLSID* proxyClsid = nullptr;
+ GetProxyDllInfoFn(&proxyInfo, &proxyClsid);
+ if (!proxyInfo || !proxyClsid) {
+ FreeLibrary(proxyDll);
+ return nullptr;
+ }
+
+ IUnknown* classObject = nullptr;
+ HRESULT hr = DllGetClassObjectFn(*proxyClsid, IID_IUnknown,
+ (void**) &classObject);
+ if (FAILED(hr)) {
+ FreeLibrary(proxyDll);
+ return nullptr;
+ }
+
+ DWORD regCookie;
+ hr = CoRegisterClassObject(*proxyClsid, classObject, CLSCTX_INPROC_SERVER,
+ REGCLS_MULTIPLEUSE, ®Cookie);
+ if (FAILED(hr)) {
+ classObject->lpVtbl->Release(classObject);
+ FreeLibrary(proxyDll);
+ return nullptr;
+ }
+
+ ITypeLib* typeLib = nullptr;
+ hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ FreeLibrary(proxyDll);
+ return nullptr;
+ }
+
+ // RegisteredProxy takes ownership of proxyDll, classObject, and typeLib
+ // references
+ auto result(MakeUnique<RegisteredProxy>(reinterpret_cast<uintptr_t>(proxyDll),
+ classObject, regCookie, typeLib));
+
+ while (*proxyInfo) {
+ const ProxyFileInfo& curInfo = **proxyInfo;
+ for (unsigned short i = 0, e = curInfo.TableSize; i < e; ++i) {
+ hr = CoRegisterPSClsid(*(curInfo.pStubVtblList[i]->header.piid),
+ *proxyClsid);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+ }
+ ++proxyInfo;
+ }
+
+ return result;
+}
+
+UniquePtr<RegisteredProxy>
+RegisterTypelib(const wchar_t* aLeafName, RegistrationFlags aFlags)
+{
+ wchar_t modulePathBuf[MAX_PATH + 1] = {0};
+ if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf),
+ aLeafName)) {
+ return nullptr;
+ }
+
+ ITypeLib* typeLib = nullptr;
+ HRESULT hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ // RegisteredProxy takes ownership of typeLib reference
+ auto result(MakeUnique<RegisteredProxy>(typeLib));
+ return result;
+}
+
+RegisteredProxy::RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject,
+ uint32_t aRegCookie, ITypeLib* aTypeLib)
+ : mModule(aModule)
+ , mClassObject(aClassObject)
+ , mRegCookie(aRegCookie)
+ , mTypeLib(aTypeLib)
+{
+ MOZ_ASSERT(aClassObject);
+ MOZ_ASSERT(aTypeLib);
+ AddToRegistry(this);
+}
+
+RegisteredProxy::RegisteredProxy(ITypeLib* aTypeLib)
+ : mModule(0)
+ , mClassObject(nullptr)
+ , mRegCookie(0)
+ , mTypeLib(aTypeLib)
+{
+ MOZ_ASSERT(aTypeLib);
+ AddToRegistry(this);
+}
+
+RegisteredProxy::~RegisteredProxy()
+{
+ DeleteFromRegistry(this);
+ if (mTypeLib) {
+ mTypeLib->lpVtbl->Release(mTypeLib);
+ }
+ if (mClassObject) {
+ ::CoRevokeClassObject(mRegCookie);
+ mClassObject->lpVtbl->Release(mClassObject);
+ }
+ if (mModule) {
+ ::FreeLibrary(reinterpret_cast<HMODULE>(mModule));
+ }
+}
+
+RegisteredProxy::RegisteredProxy(RegisteredProxy&& aOther)
+{
+ *this = mozilla::Forward<RegisteredProxy>(aOther);
+}
+
+RegisteredProxy&
+RegisteredProxy::operator=(RegisteredProxy&& aOther)
+{
+ mModule = aOther.mModule;
+ aOther.mModule = 0;
+ mClassObject = aOther.mClassObject;
+ aOther.mClassObject = nullptr;
+ mRegCookie = aOther.mRegCookie;
+ aOther.mRegCookie = 0;
+ mTypeLib = aOther.mTypeLib;
+ aOther.mTypeLib = nullptr;
+ return *this;
+}
+
+HRESULT
+RegisteredProxy::GetTypeInfoForInterface(REFIID aIid,
+ ITypeInfo** aOutTypeInfo) const
+{
+ if (!aOutTypeInfo) {
+ return E_INVALIDARG;
+ }
+ if (!mTypeLib) {
+ return E_UNEXPECTED;
+ }
+ return mTypeLib->lpVtbl->GetTypeInfoOfGuid(mTypeLib, aIid, aOutTypeInfo);
+}
+
+static StaticAutoPtr<nsTArray<RegisteredProxy*>> sRegistry;
+static StaticAutoPtr<Mutex> sRegMutex;
+static StaticAutoPtr<nsTArray<Pair<const ArrayData*, size_t>>> sArrayData;
+
+static Mutex&
+GetMutex()
+{
+ static Mutex& mutex = []() -> Mutex& {
+ if (!sRegMutex) {
+ sRegMutex = new Mutex("RegisteredProxy::sRegMutex");
+ ClearOnShutdown(&sRegMutex, ShutdownPhase::ShutdownThreads);
+ }
+ return *sRegMutex;
+ }();
+ return mutex;
+}
+
+/* static */ bool
+RegisteredProxy::Find(REFIID aIid, ITypeInfo** aTypeInfo)
+{
+ MutexAutoLock lock(GetMutex());
+ nsTArray<RegisteredProxy*>& registry = *sRegistry;
+ for (uint32_t idx = 0, len = registry.Length(); idx < len; ++idx) {
+ if (SUCCEEDED(registry[idx]->GetTypeInfoForInterface(aIid, aTypeInfo))) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/* static */ void
+RegisteredProxy::AddToRegistry(RegisteredProxy* aProxy)
+{
+ MutexAutoLock lock(GetMutex());
+ if (!sRegistry) {
+ sRegistry = new nsTArray<RegisteredProxy*>();
+ ClearOnShutdown(&sRegistry);
+ }
+ sRegistry->AppendElement(aProxy);
+}
+
+/* static */ void
+RegisteredProxy::DeleteFromRegistry(RegisteredProxy* aProxy)
+{
+ MutexAutoLock lock(GetMutex());
+ sRegistry->RemoveElement(aProxy);
+}
+
+void
+RegisterArrayData(const ArrayData* aArrayData, size_t aLength)
+{
+ MutexAutoLock lock(GetMutex());
+ if (!sArrayData) {
+ sArrayData = new nsTArray<Pair<const ArrayData*, size_t>>();
+ ClearOnShutdown(&sArrayData, ShutdownPhase::ShutdownThreads);
+ }
+ sArrayData->AppendElement(MakePair(aArrayData, aLength));
+}
+
+const ArrayData*
+FindArrayData(REFIID aIid, ULONG aMethodIndex)
+{
+ MutexAutoLock lock(GetMutex());
+ if (!sArrayData) {
+ return nullptr;
+ }
+ for (uint32_t outerIdx = 0, outerLen = sArrayData->Length();
+ outerIdx < outerLen; ++outerIdx) {
+ auto& data = sArrayData->ElementAt(outerIdx);
+ for (size_t innerIdx = 0, innerLen = data.second(); innerIdx < innerLen;
+ ++innerIdx) {
+ const ArrayData* array = data.first();
+ if (aIid == array[innerIdx].mIid &&
+ aMethodIndex == array[innerIdx].mMethodIndex) {
+ return &array[innerIdx];
+ }
+ }
+ }
+ return nullptr;
+}
+
+} // namespace mscom
+} // namespace mozilla
new file mode 100644
--- /dev/null
+++ b/ipc/mscom/Registration.h
@@ -0,0 +1,132 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Registration_h
+#define mozilla_mscom_Registration_h
+
+#include "mozilla/RefPtr.h"
+#include "mozilla/UniquePtr.h"
+
+#include <objbase.h>
+
+struct ITypeInfo;
+struct ITypeLib;
+
+namespace mozilla {
+namespace mscom {
+
+/**
+ * Assumptions:
+ * (1) The DLL exports GetProxyDllInfo. This is not exported by default; it must
+ * be specified in the EXPORTS section of the DLL's module definition file.
+ */
+class RegisteredProxy
+{
+public:
+ RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject,
+ uint32_t aRegCookie, ITypeLib* aTypeLib);
+ explicit RegisteredProxy(ITypeLib* aTypeLib);
+ RegisteredProxy(RegisteredProxy&& aOther);
+ RegisteredProxy& operator=(RegisteredProxy&& aOther);
+
+ ~RegisteredProxy();
+
+ HRESULT GetTypeInfoForInterface(REFIID aIid, ITypeInfo** aOutTypeInfo) const;
+
+ static bool Find(REFIID aIid, ITypeInfo** aOutTypeInfo);
+
+private:
+ RegisteredProxy() = delete;
+ RegisteredProxy(RegisteredProxy&) = delete;
+ RegisteredProxy& operator=(RegisteredProxy&) = delete;
+
+ static void AddToRegistry(RegisteredProxy* aProxy);
+ static void DeleteFromRegistry(RegisteredProxy* aProxy);
+
+private:
+ // Not using Windows types here: We shouldn't #include windows.h
+ // since it might pull in COM code which we want to do very carefully in
+ // Registration.cpp.
+ uintptr_t mModule;
+ IUnknown* mClassObject;
+ uint32_t mRegCookie;
+ ITypeLib* mTypeLib;
+};
+
+enum class RegistrationFlags
+{
+ eUseBinDirectory,
+ eUseSystemDirectory
+};
+
+// For DLL files. Assumes corresponding TLB is embedded in resources.
+UniquePtr<RegisteredProxy> RegisterProxy(const wchar_t* aLeafName,
+ RegistrationFlags aFlags =
+ RegistrationFlags::eUseBinDirectory);
+// For standalone TLB files.
+UniquePtr<RegisteredProxy> RegisterTypelib(const wchar_t* aLeafName,
+ RegistrationFlags aFlags =
+ RegistrationFlags::eUseBinDirectory);
+
+/**
+ * The COM interceptor uses type library information to build its interface
+ * proxies. Unfortunately type libraries do not encode size_is and length_is
+ * annotations that have been specified in IDL. This structure allows us to
+ * explicitly declare such relationships so that the COM interceptor may
+ * be made aware of them.
+ */
+struct ArrayData
+{
+ ArrayData(REFIID aIid, ULONG aMethodIndex, ULONG aArrayParamIndex,
+ VARTYPE aArrayParamType, REFIID aArrayParamIid,
+ ULONG aLengthParamIndex)
+ : mIid(aIid)
+ , mMethodIndex(aMethodIndex)
+ , mArrayParamIndex(aArrayParamIndex)
+ , mArrayParamType(aArrayParamType)
+ , mArrayParamIid(aArrayParamIid)
+ , mLengthParamIndex(aLengthParamIndex)
+ {
+ }
+ ArrayData(const ArrayData& aOther)
+ {
+ *this = aOther;
+ }
+ ArrayData& operator=(const ArrayData& aOther)
+ {
+ mIid = aOther.mIid;
+ mMethodIndex = aOther.mMethodIndex;
+ mArrayParamIndex = aOther.mArrayParamIndex;
+ mArrayParamType = aOther.mArrayParamType;
+ mArrayParamIid = aOther.mArrayParamIid;
+ mLengthParamIndex = aOther.mLengthParamIndex;
+ return *this;
+ }
+ IID mIid;
+ ULONG mMethodIndex;
+ ULONG mArrayParamIndex;
+ VARTYPE mArrayParamType;
+ IID mArrayParamIid;
+ ULONG mLengthParamIndex;
+};
+
+void RegisterArrayData(const ArrayData* aArrayData, size_t aLength);
+
+template <size_t N>
+inline void
+RegisterArrayData(const ArrayData (&aData)[N])
+{
+ RegisterArrayData(aData, N);
+}
+
+const ArrayData*
+FindArrayData(REFIID aIid, ULONG aMethodIndex);
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Registration_h
+
--- a/ipc/mscom/moz.build
+++ b/ipc/mscom/moz.build
@@ -5,20 +5,22 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
EXPORTS.mozilla.mscom += [
'COMApartmentRegion.h',
'COMPtrHolder.h',
'EnsureMTA.h',
'MainThreadRuntime.h',
'ProxyStream.h',
+ 'Registration.h',
'Utils.h',
]
SOURCES += [
+ 'Registration.cpp',
'Utils.cpp',
]
UNIFIED_SOURCES += [
'EnsureMTA.cpp',
'MainThreadRuntime.cpp',
'ProxyStream.cpp',
]