--- a/ipc/mscom/Registration.cpp
+++ b/ipc/mscom/Registration.cpp
@@ -27,45 +27,74 @@
#include <oaidl.h>
#include <objidl.h>
#include <rpcproxy.h>
#include <shlwapi.h>
/* This code MUST NOT use any non-inlined internal Mozilla APIs, as it will be
compiled into DLLs that COM may load into non-Mozilla processes! */
-namespace {
+extern "C" {
// This function is defined in generated code for proxy DLLs but is not declared
-// in rpcproxy.h, so we need this typedef.
-typedef void (RPC_ENTRY *GetProxyDllInfoFnPtr)(const ProxyFileInfo*** aInfo,
- const CLSID** aId);
+// in rpcproxy.h, so we need this declaration.
+void RPC_ENTRY GetProxyDllInfo(const ProxyFileInfo*** aInfo, const CLSID** aId);
-} // anonymous namespace
+#if defined(_MSC_VER)
+extern IMAGE_DOS_HEADER __ImageBase;
+#endif
+
+}
namespace mozilla {
namespace mscom {
+static HMODULE
+GetContainingModule()
+{
+ HMODULE thisModule = nullptr;
+#if defined(_MSC_VER)
+ thisModule = reinterpret_cast<HMODULE>(&__ImageBase);
+#else
+ if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_UNCHANGED_REFCOUNT,
+ reinterpret_cast<LPCTSTR>(&GetContainingModule),
+ &thisModule)) {
+ return nullptr;
+ }
+#endif
+ return thisModule;
+}
+
+static bool
+GetContainingLibPath(wchar_t* aBuffer, size_t aBufferLen)
+{
+ HMODULE thisModule = GetContainingModule();
+ if (!thisModule) {
+ return false;
+ }
+
+ DWORD fileNameResult = GetModuleFileName(thisModule, aBuffer, aBufferLen);
+ if (!fileNameResult || (fileNameResult == aBufferLen &&
+ ::GetLastError() == ERROR_INSUFFICIENT_BUFFER)) {
+ return false;
+ }
+
+ return true;
+}
+
static bool
BuildLibPath(RegistrationFlags aFlags, wchar_t* aBuffer, size_t aBufferLen,
const wchar_t* aLeafName)
{
if (aFlags == RegistrationFlags::eUseBinDirectory) {
- HMODULE thisModule = nullptr;
- if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
- GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
- reinterpret_cast<LPCTSTR>(&RegisterProxy),
- &thisModule)) {
+ if (!GetContainingLibPath(aBuffer, aBufferLen)) {
return false;
}
- DWORD fileNameResult = GetModuleFileName(thisModule, aBuffer, aBufferLen);
- if (!fileNameResult || (fileNameResult == aBufferLen &&
- ::GetLastError() == ERROR_INSUFFICIENT_BUFFER)) {
- return false;
- }
+
if (!PathRemoveFileSpec(aBuffer)) {
return false;
}
} else if (aFlags == RegistrationFlags::eUseSystemDirectory) {
UINT result = GetSystemDirectoryW(aBuffer, static_cast<UINT>(aBufferLen));
if (!result || result > aBufferLen) {
return false;
}
@@ -74,16 +103,84 @@ BuildLibPath(RegistrationFlags aFlags, w
}
if (!PathAppend(aBuffer, aLeafName)) {
return false;
}
return true;
}
+static bool
+RegisterPSClsids(const ProxyFileInfo** aProxyInfo, const CLSID* aProxyClsid)
+{
+ while (*aProxyInfo) {
+ const ProxyFileInfo& curInfo = **aProxyInfo;
+ for (unsigned short idx = 0, size = curInfo.TableSize; idx < size; ++idx) {
+ HRESULT hr = CoRegisterPSClsid(*(curInfo.pStubVtblList[idx]->header.piid),
+ *aProxyClsid);
+ if (FAILED(hr)) {
+ return false;
+ }
+ }
+ ++aProxyInfo;
+ }
+
+ return true;
+}
+
+UniquePtr<RegisteredProxy>
+RegisterProxy()
+{
+ const ProxyFileInfo** proxyInfo = nullptr;
+ const CLSID* proxyClsid = nullptr;
+ GetProxyDllInfo(&proxyInfo, &proxyClsid);
+ if (!proxyInfo || !proxyClsid) {
+ return nullptr;
+ }
+
+ IUnknown* classObject = nullptr;
+ HRESULT hr = DllGetClassObject(*proxyClsid, IID_IUnknown, (void**)&classObject);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ DWORD regCookie;
+ hr = CoRegisterClassObject(*proxyClsid, classObject, CLSCTX_INPROC_SERVER,
+ REGCLS_MULTIPLEUSE, ®Cookie);
+ if (FAILED(hr)) {
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ wchar_t modulePathBuf[MAX_PATH + 1] = {0};
+ if (!GetContainingLibPath(modulePathBuf, ArrayLength(modulePathBuf))) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ ITypeLib* typeLib = nullptr;
+ hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ // RegisteredProxy takes ownership of classObject and typeLib references
+ auto result(MakeUnique<RegisteredProxy>(classObject, regCookie, typeLib));
+
+ if (!RegisterPSClsids(proxyInfo, proxyClsid)) {
+ return nullptr;
+ }
+
+ return result;
+}
+
UniquePtr<RegisteredProxy>
RegisterProxy(const wchar_t* aLeafName, RegistrationFlags aFlags)
{
wchar_t modulePathBuf[MAX_PATH + 1] = {0};
if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf),
aLeafName)) {
return nullptr;
}
@@ -95,17 +192,17 @@ RegisterProxy(const wchar_t* aLeafName,
// Instantiate an activation context so that CoGetClassObject will use any
// COM metadata embedded in proxyDll's manifest to resolve CLSIDs.
ActivationContext actCtx(proxyDll);
if (!actCtx) {
return nullptr;
}
- auto GetProxyDllInfoFn = reinterpret_cast<GetProxyDllInfoFnPtr>(
+ auto GetProxyDllInfoFn = reinterpret_cast<decltype(&GetProxyDllInfo)>(
GetProcAddress(proxyDll, "GetProxyDllInfo"));
if (!GetProxyDllInfoFn) {
return nullptr;
}
const ProxyFileInfo** proxyInfo = nullptr;
const CLSID* proxyClsid = nullptr;
GetProxyDllInfoFn(&proxyInfo, &proxyClsid);
@@ -139,26 +236,18 @@ RegisterProxy(const wchar_t* aLeafName,
return nullptr;
}
// RegisteredProxy takes ownership of proxyDll, classObject, and typeLib
// references
auto result(MakeUnique<RegisteredProxy>(reinterpret_cast<uintptr_t>(proxyDll.disown()),
classObject, regCookie, typeLib));
- while (*proxyInfo) {
- const ProxyFileInfo& curInfo = **proxyInfo;
- for (unsigned short i = 0, e = curInfo.TableSize; i < e; ++i) {
- hr = CoRegisterPSClsid(*(curInfo.pStubVtblList[i]->header.piid),
- *proxyClsid);
- if (FAILED(hr)) {
- return nullptr;
- }
- }
- ++proxyInfo;
+ if (!RegisterPSClsids(proxyInfo, proxyClsid)) {
+ return nullptr;
}
return result;
}
UniquePtr<RegisteredProxy>
RegisterTypelib(const wchar_t* aLeafName, RegistrationFlags aFlags)
{
@@ -187,16 +276,29 @@ RegisteredProxy::RegisteredProxy(uintptr
, mTypeLib(aTypeLib)
, mIsRegisteredInMTA(IsCurrentThreadMTA())
{
MOZ_ASSERT(aClassObject);
MOZ_ASSERT(aTypeLib);
AddToRegistry(this);
}
+RegisteredProxy::RegisteredProxy(IUnknown* aClassObject, uint32_t aRegCookie,
+ ITypeLib* aTypeLib)
+ : mModule(0)
+ , mClassObject(aClassObject)
+ , mRegCookie(aRegCookie)
+ , mTypeLib(aTypeLib)
+ , mIsRegisteredInMTA(IsCurrentThreadMTA())
+{
+ MOZ_ASSERT(aClassObject);
+ MOZ_ASSERT(aTypeLib);
+ AddToRegistry(this);
+}
+
// If we're initializing from a typelib, it doesn't matter which apartment we
// run in, so mIsRegisteredInMTA may always be set to false in this case.
RegisteredProxy::RegisteredProxy(ITypeLib* aTypeLib)
: mModule(0)
, mClassObject(nullptr)
, mRegCookie(0)
, mTypeLib(aTypeLib)
, mIsRegisteredInMTA(false)
--- a/ipc/mscom/Registration.h
+++ b/ipc/mscom/Registration.h
@@ -23,16 +23,18 @@ namespace mscom {
* (1) The DLL exports GetProxyDllInfo. This is not exported by default; it must
* be specified in the EXPORTS section of the DLL's module definition file.
*/
class RegisteredProxy
{
public:
RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject,
uint32_t aRegCookie, ITypeLib* aTypeLib);
+ RegisteredProxy(IUnknown* aClassObject, uint32_t aRegCookie,
+ ITypeLib* aTypeLib);
explicit RegisteredProxy(ITypeLib* aTypeLib);
RegisteredProxy(RegisteredProxy&& aOther);
RegisteredProxy& operator=(RegisteredProxy&& aOther);
~RegisteredProxy();
HRESULT GetTypeInfoForInterface(REFIID aIid, ITypeInfo** aOutTypeInfo) const;
@@ -58,16 +60,20 @@ private:
};
enum class RegistrationFlags
{
eUseBinDirectory,
eUseSystemDirectory
};
+// For our own DLL that we are currently executing in (ie, xul).
+// Assumes corresponding TLB is embedded in resources.
+UniquePtr<RegisteredProxy> RegisterProxy();
+
// For DLL files. Assumes corresponding TLB is embedded in resources.
UniquePtr<RegisteredProxy> RegisterProxy(const wchar_t* aLeafName,
RegistrationFlags aFlags =
RegistrationFlags::eUseBinDirectory);
// For standalone TLB files.
UniquePtr<RegisteredProxy> RegisterTypelib(const wchar_t* aLeafName,
RegistrationFlags aFlags =
RegistrationFlags::eUseBinDirectory);