Bug 1389980 - Ensure we only interact with WMF on MTA threads. r=mattwoodrow,r=aklotz draft
authorChris Pearce <cpearce@mozilla.com>
Tue, 15 Aug 2017 10:00:14 +1200
changeset 649625 a019a88f569c21b3ac4a97ea40e2329824dc1315
parent 649624 6a94c6da43a581702069d880fe480dfb668af7f0
child 727116 3e1e764c5ae8f4c5836096faba53e8d661f060ef
push id75074
push userbmo:cpearce@mozilla.com
push dateMon, 21 Aug 2017 01:29:11 +0000
reviewersmattwoodrow, aklotz
bugs1389980
milestone57.0a1
Bug 1389980 - Ensure we only interact with WMF on MTA threads. r=mattwoodrow,r=aklotz The IMFTransform interface used by MFTDecoder is documented to require to run on an MTA threads: https://msdn.microsoft.com/en-us/library/windows/desktop/ee892371(v=vs.85).aspx#components We're currently using IMFTransform objects on the main thread, which is STA. So delegate calls to the IMFTransform to the MTA thread when necessary, to ensure it always runs on an MTA thread. The existing uses of IMFTransform objects in the decode thread pool threads will be fine, as those threads are already MTA. We also defer initialization of WMF to the MTA thread, so that we're always interacting with WMF on an MTA thread. MozReview-Commit-ID: Dm8XpdvJLkS
dom/media/platforms/wmf/DXVA2Manager.cpp
dom/media/platforms/wmf/MFTDecoder.cpp
dom/media/platforms/wmf/WMF.h
dom/media/platforms/wmf/WMFDecoderModule.cpp
dom/media/platforms/wmf/WMFUtils.cpp
--- a/dom/media/platforms/wmf/DXVA2Manager.cpp
+++ b/dom/media/platforms/wmf/DXVA2Manager.cpp
@@ -19,16 +19,17 @@
 #include "mozilla/gfx/DeviceManagerDx.h"
 #include "mozilla/layers/D3D11ShareHandleImage.h"
 #include "mozilla/layers/ImageBridgeChild.h"
 #include "mozilla/layers/TextureForwarder.h"
 #include "mozilla/layers/TextureD3D11.h"
 #include "nsPrintfCString.h"
 #include "nsThreadUtils.h"
 #include "VideoUtils.h"
+#include "mozilla/mscom/EnsureMTA.h"
 
 const CLSID CLSID_VideoProcessorMFT =
 {
   0x88753b26,
   0x5b24,
   0x49bd,
   { 0xb2, 0xe7, 0xc, 0x44, 0x5c, 0x78, 0xc9, 0x82 }
 };
@@ -643,16 +644,17 @@ private:
   RefPtr<MFTDecoder> mTransform;
   RefPtr<D3D11RecycleAllocator> mTextureClientAllocator;
   RefPtr<ID3D11VideoDecoder> mDecoder;
   RefPtr<layers::SyncObjectClient> mSyncObject;
   GUID mDecoderGUID;
   uint32_t mWidth = 0;
   uint32_t mHeight = 0;
   UINT mDeviceManagerToken = 0;
+  bool mConfiuredForSize = false;
 };
 
 bool
 D3D11DXVA2Manager::SupportsConfig(IMFMediaType* aType, float aFramerate)
 {
   MOZ_ASSERT(NS_IsMainThread());
   D3D11_VIDEO_DECODER_DESC desc;
   desc.Guid = mDecoderGUID;
@@ -773,32 +775,45 @@ D3D11DXVA2Manager::InitInternal(layers::
 
   hr = mDXGIDeviceManager->ResetDevice(mDevice, mDeviceManagerToken);
   if (!SUCCEEDED(hr)) {
     aFailureReason = nsPrintfCString(
       "IMFDXGIDeviceManager::ResetDevice failed with code %X", hr);
     return hr;
   }
 
-  mTransform = new MFTDecoder();
-  hr = mTransform->Create(CLSID_VideoProcessorMFT);
+  // The IMFTransform interface used by MFTDecoder is documented to require to
+  // run on an MTA thread.
+  // https://msdn.microsoft.com/en-us/library/windows/desktop/ee892371(v=vs.85).aspx#components
+  // The main thread (where this function is called) is STA, not MTA.
+  RefPtr<MFTDecoder> mft;
+  mozilla::mscom::EnsureMTA([&]() -> void {
+    mft = new MFTDecoder();
+    hr = mft->Create(CLSID_VideoProcessorMFT);
+
+    if (!SUCCEEDED(hr)) {
+      aFailureReason = nsPrintfCString(
+        "MFTDecoder::Create(CLSID_VideoProcessorMFT) failed with code %X", hr);
+      return;
+    }
+
+    hr = mft->SendMFTMessage(MFT_MESSAGE_SET_D3D_MANAGER,
+                             ULONG_PTR(mDXGIDeviceManager.get()));
+    if (!SUCCEEDED(hr)) {
+      aFailureReason = nsPrintfCString("MFTDecoder::SendMFTMessage(MFT_MESSAGE_"
+                                       "SET_D3D_MANAGER) failed with code %X",
+                                       hr);
+      return;
+    }
+  });
+
   if (!SUCCEEDED(hr)) {
-    aFailureReason = nsPrintfCString(
-      "MFTDecoder::Create(CLSID_VideoProcessorMFT) failed with code %X", hr);
     return hr;
   }
-
-  hr = mTransform->SendMFTMessage(MFT_MESSAGE_SET_D3D_MANAGER,
-                                  ULONG_PTR(mDXGIDeviceManager.get()));
-  if (!SUCCEEDED(hr)) {
-    aFailureReason = nsPrintfCString("MFTDecoder::SendMFTMessage(MFT_MESSAGE_"
-                                     "SET_D3D_MANAGER) failed with code %X",
-                                     hr);
-    return hr;
-  }
+  mTransform = mft;
 
   RefPtr<ID3D11VideoDevice> videoDevice;
   hr = mDevice->QueryInterface(
     static_cast<ID3D11VideoDevice**>(getter_AddRefs(videoDevice)));
   if (!SUCCEEDED(hr)) {
     aFailureReason =
       nsPrintfCString("QI to ID3D11VideoDevice failed with code %X", hr);
     return hr;
@@ -942,24 +957,28 @@ D3D11DXVA2Manager::CopyToImage(IMFSample
       NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
       UINT index;
       dxgiBuf->GetSubresourceIndex(&index);
       mContext->CopySubresourceRegion(texture, 0, 0, 0, 0, tex, index, nullptr);
     } else {
       // Our video sample is in NV12 format but our output texture is in BGRA.
       // Use MFT to do color conversion.
-      hr = mTransform->Input(aVideoSample);
+      hr = E_FAIL;
+      mozilla::mscom::EnsureMTA(
+        [&]() -> void { hr = mTransform->Input(aVideoSample); });
       NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
       RefPtr<IMFSample> sample;
       hr = CreateOutputSample(sample, texture);
       NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
-      hr = mTransform->Output(&sample);
+      hr = E_FAIL;
+      mozilla::mscom::EnsureMTA(
+        [&]() -> void { hr = mTransform->Output(&sample); });
       NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
     }
   }
 
   if (!mutex && mDevice != DeviceManagerDx::Get()->GetCompositorDevice() && mSyncObject) {
     // It appears some race-condition may allow us to arrive here even when mSyncObject
     // is null. It's better to avoid that crash.
     client->SyncWithObject(mSyncObject);
@@ -1024,24 +1043,29 @@ D3D11DXVA2Manager::CopyToBGRATexture(ID3
 
   RefPtr<IMFMediaBuffer> inputBuffer;
   hr = wmf::MFCreateDXGISurfaceBuffer(
     __uuidof(ID3D11Texture2D), inTexture, 0, FALSE, getter_AddRefs(inputBuffer));
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   inputSample->AddBuffer(inputBuffer);
 
-  hr = mTransform->Input(inputSample);
+  hr = E_FAIL;
+  mozilla::mscom::EnsureMTA(
+    [&]() -> void { hr = mTransform->Input(inputSample); });
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   RefPtr<IMFSample> outputSample;
   hr = CreateOutputSample(outputSample, texture);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
-  hr = mTransform->Output(&outputSample);
+  hr = E_FAIL;
+  mozilla::mscom::EnsureMTA(
+    [&]() -> void { hr = mTransform->Output(&outputSample); });
+  NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   texture.forget(aOutTexture);
 
   return S_OK;
 }
 
 HRESULT ConfigureOutput(IMFMediaType* aOutput, void* aData)
 {
@@ -1057,16 +1081,21 @@ HRESULT ConfigureOutput(IMFMediaType* aO
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   return S_OK;
 }
 
 HRESULT
 D3D11DXVA2Manager::ConfigureForSize(uint32_t aWidth, uint32_t aHeight)
 {
+  if (mConfiuredForSize && aWidth == mWidth && aHeight == mHeight) {
+    // If the size hasn't changed, don't reconfigure.
+    return S_OK;
+  }
+
   mWidth = aWidth;
   mHeight = aHeight;
 
   RefPtr<IMFMediaType> inputType;
   HRESULT hr = wmf::MFCreateMediaType(getter_AddRefs(inputType));
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = inputType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Video);
@@ -1076,17 +1105,20 @@ D3D11DXVA2Manager::ConfigureForSize(uint
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = inputType->SetUINT32(MF_MT_INTERLACE_MODE, MFVideoInterlace_Progressive);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = inputType->SetUINT32(MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
-  RefPtr<IMFAttributes> attr = mTransform->GetAttributes();
+  RefPtr<IMFAttributes> attr;
+  mozilla::mscom::EnsureMTA(
+    [&]() -> void { attr = mTransform->GetAttributes(); });
+  NS_ENSURE_TRUE(attr != nullptr, E_FAIL);
 
   hr = attr->SetUINT32(MF_XVP_PLAYBACK_MODE, TRUE);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = attr->SetUINT32(MF_LOW_LATENCY, FALSE);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = MFSetAttributeSize(inputType, MF_MT_FRAME_SIZE, aWidth, aHeight);
@@ -1098,19 +1130,25 @@ D3D11DXVA2Manager::ConfigureForSize(uint
 
   hr = outputType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Video);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = outputType->SetGUID(MF_MT_SUBTYPE, MFVideoFormat_ARGB32);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   gfx::IntSize size(mWidth, mHeight);
-  hr = mTransform->SetMediaTypes(inputType, outputType, ConfigureOutput, &size);
+  hr = E_FAIL;
+  mozilla::mscom::EnsureMTA([&]() -> void {
+    hr =
+      mTransform->SetMediaTypes(inputType, outputType, ConfigureOutput, &size);
+  });
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
+  mConfiuredForSize = true;
+
   return S_OK;
 }
 
 bool
 D3D11DXVA2Manager::CreateDXVA2Decoder(const VideoInfo& aVideoInfo,
                                       nsACString& aFailureReason)
 {
   MOZ_ASSERT(NS_IsMainThread());
--- a/dom/media/platforms/wmf/MFTDecoder.cpp
+++ b/dom/media/platforms/wmf/MFTDecoder.cpp
@@ -3,16 +3,17 @@
 /* This Source Code Form is subject to the terms of the Mozilla Public
  * License, v. 2.0. If a copy of the MPL was not distributed with this
  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
 
 #include "MFTDecoder.h"
 #include "WMFUtils.h"
 #include "mozilla/Logging.h"
 #include "nsThreadUtils.h"
+#include "mozilla/mscom/Utils.h"
 
 #define LOG(...) MOZ_LOG(sPDMLog, mozilla::LogLevel::Debug, (__VA_ARGS__))
 
 namespace mozilla {
 
 MFTDecoder::MFTDecoder()
 {
   memset(&mInputStreamInfo, 0, sizeof(MFT_INPUT_STREAM_INFO));
@@ -21,16 +22,18 @@ MFTDecoder::MFTDecoder()
 
 MFTDecoder::~MFTDecoder()
 {
 }
 
 HRESULT
 MFTDecoder::Create(const GUID& aMFTClsID)
 {
+  // Note: IMFTransform is documented to only be safe on MTA threads.
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   // Create the IMFTransform to do the decoding.
   HRESULT hr;
   hr = CoCreateInstance(aMFTClsID,
                         nullptr,
                         CLSCTX_INPROC_SERVER,
                         IID_IMFTransform,
                         reinterpret_cast<void**>(static_cast<IMFTransform**>(
                           getter_AddRefs(mDecoder))));
@@ -40,16 +43,17 @@ MFTDecoder::Create(const GUID& aMFTClsID
 }
 
 HRESULT
 MFTDecoder::SetMediaTypes(IMFMediaType* aInputType,
                           IMFMediaType* aOutputType,
                           ConfigureOutputCallback aCallback,
                           void* aData)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   mOutputType = aOutputType;
 
   // Set the input type to the one the caller gave us...
   HRESULT hr = mDecoder->SetInputType(0, aInputType, 0);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   hr = SetDecoderOutputType(true /* match all attributes */, aCallback, aData);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
@@ -64,27 +68,29 @@ MFTDecoder::SetMediaTypes(IMFMediaType* 
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   return S_OK;
 }
 
 already_AddRefed<IMFAttributes>
 MFTDecoder::GetAttributes()
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   RefPtr<IMFAttributes> attr;
   HRESULT hr = mDecoder->GetAttributes(getter_AddRefs(attr));
   NS_ENSURE_TRUE(SUCCEEDED(hr), nullptr);
   return attr.forget();
 }
 
 HRESULT
 MFTDecoder::SetDecoderOutputType(bool aMatchAllAttributes,
                                  ConfigureOutputCallback aCallback,
                                  void* aData)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder != nullptr, E_POINTER);
 
   GUID currentSubtype = {0};
   HRESULT hr = mOutputType->GetGUID(MF_MT_SUBTYPE, &currentSubtype);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   // Iterate the enumerate the output types, until we find one compatible
   // with what we need.
@@ -122,28 +128,30 @@ MFTDecoder::SetDecoderOutputType(bool aM
     outputType = nullptr;
   }
   return E_FAIL;
 }
 
 HRESULT
 MFTDecoder::SendMFTMessage(MFT_MESSAGE_TYPE aMsg, ULONG_PTR aData)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder != nullptr, E_POINTER);
   HRESULT hr = mDecoder->ProcessMessage(aMsg, aData);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
   return S_OK;
 }
 
 HRESULT
 MFTDecoder::CreateInputSample(const uint8_t* aData,
                               uint32_t aDataSize,
                               int64_t aTimestamp,
                               RefPtr<IMFSample>* aOutSample)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder != nullptr, E_POINTER);
 
   HRESULT hr;
   RefPtr<IMFSample> sample;
   hr = wmf::MFCreateSample(getter_AddRefs(sample));
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   RefPtr<IMFMediaBuffer> buffer;
@@ -179,16 +187,17 @@ MFTDecoder::CreateInputSample(const uint
   *aOutSample = sample.forget();
 
   return S_OK;
 }
 
 HRESULT
 MFTDecoder::CreateOutputSample(RefPtr<IMFSample>* aOutSample)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder != nullptr, E_POINTER);
 
   HRESULT hr;
   RefPtr<IMFSample> sample;
   hr = wmf::MFCreateSample(getter_AddRefs(sample));
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   RefPtr<IMFMediaBuffer> buffer;
@@ -205,16 +214,17 @@ MFTDecoder::CreateOutputSample(RefPtr<IM
   *aOutSample = sample.forget();
 
   return S_OK;
 }
 
 HRESULT
 MFTDecoder::Output(RefPtr<IMFSample>* aOutput)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder != nullptr, E_POINTER);
 
   HRESULT hr;
 
   MFT_OUTPUT_DATA_BUFFER output = {0};
 
   bool providedSample = false;
   RefPtr<IMFSample> sample;
@@ -270,49 +280,53 @@ MFTDecoder::Output(RefPtr<IMFSample>* aO
   return S_OK;
 }
 
 HRESULT
 MFTDecoder::Input(const uint8_t* aData,
                   uint32_t aDataSize,
                   int64_t aTimestamp)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder != nullptr, E_POINTER);
 
   RefPtr<IMFSample> input;
   HRESULT hr = CreateInputSample(aData, aDataSize, aTimestamp, &input);
   NS_ENSURE_TRUE(SUCCEEDED(hr) && input != nullptr, hr);
 
   return Input(input);
 }
 
 HRESULT
 MFTDecoder::Input(IMFSample* aSample)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   HRESULT hr = mDecoder->ProcessInput(0, aSample, 0);
   if (hr == MF_E_NOTACCEPTING) {
     // MFT *already* has enough data to produce a sample. Retrieve it.
     return MF_E_NOTACCEPTING;
   }
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   return S_OK;
 }
 
 HRESULT
 MFTDecoder::Flush()
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   HRESULT hr = SendMFTMessage(MFT_MESSAGE_COMMAND_FLUSH, 0);
   NS_ENSURE_TRUE(SUCCEEDED(hr), hr);
 
   mDiscontinuity = true;
 
   return S_OK;
 }
 
 HRESULT
 MFTDecoder::GetOutputMediaType(RefPtr<IMFMediaType>& aMediaType)
 {
+  MOZ_ASSERT(mscom::IsCurrentThreadMTA());
   NS_ENSURE_TRUE(mDecoder, E_POINTER);
   return mDecoder->GetOutputCurrentType(0, getter_AddRefs(aMediaType));
 }
 
 } // namespace mozilla
--- a/dom/media/platforms/wmf/WMF.h
+++ b/dom/media/platforms/wmf/WMF.h
@@ -37,22 +37,27 @@
 extern "C" const CLSID CLSID_CMSAACDecMFT;
 #define WMF_MUST_DEFINE_AAC_MFT_CLSID
 #endif
 
 namespace mozilla {
 namespace wmf {
 
 // If successful, loads all required WMF DLLs and calls the WMF MFStartup()
-// function.
+// function. This delegates the WMF MFStartup() call to the MTA thread if
+// the current thread is not MTA. This is to ensure we always interact with
+// WMF from threads with the same COM compartment model.
 HRESULT MFStartup();
 
 // Calls the WMF MFShutdown() function. Call this once for every time
 // wmf::MFStartup() succeeds. Note: does not unload the WMF DLLs loaded by
 // MFStartup(); leaves them in memory to save I/O at next MFStartup() call.
+// This delegates the WMF MFShutdown() call to the MTA thread if the current
+// thread is not MTA. This is to ensure we always interact with
+// WMF from threads with the same COM compartment model.
 HRESULT MFShutdown();
 
 // All functions below are wrappers around the corresponding WMF function,
 // and automatically locate and call the corresponding function in the WMF DLLs.
 
 HRESULT MFCreateMediaType(IMFMediaType **aOutMFType);
 
 HRESULT MFGetStrideForBitmapInfoHeader(DWORD aFormat,
--- a/dom/media/platforms/wmf/WMFDecoderModule.cpp
+++ b/dom/media/platforms/wmf/WMFDecoderModule.cpp
@@ -24,16 +24,17 @@
 #include "nsAutoPtr.h"
 #include "nsComponentManagerUtils.h"
 #include "nsIGfxInfo.h"
 #include "nsIWindowsRegKey.h"
 #include "nsServiceManagerUtils.h"
 #include "nsWindowsHelpers.h"
 #include "prsystem.h"
 #include "nsIXULRuntime.h"
+#include "mozilla/mscom/EnsureMTA.h"
 
 extern const GUID CLSID_WebmMfVpxDec;
 
 namespace mozilla {
 
 static Atomic<bool> sDXVAEnabled(false);
 
 WMFDecoderModule::~WMFDecoderModule()
@@ -130,26 +131,32 @@ WMFDecoderModule::CreateAudioDecoder(con
   RefPtr<MediaDataDecoder> decoder =
     new WMFMediaDataDecoder(manager.forget(), aParams.mTaskQueue);
   return decoder.forget();
 }
 
 static bool
 CanCreateMFTDecoder(const GUID& aGuid)
 {
-  if (FAILED(wmf::MFStartup())) {
-    return false;
-  }
-  bool hasdecoder = false;
-  {
+  // The IMFTransform interface used by MFTDecoder is documented to require to
+  // run on an MTA thread.
+  // https://msdn.microsoft.com/en-us/library/windows/desktop/ee892371(v=vs.85).aspx#components
+  // Note: our normal SharedThreadPool task queues are initialized to MTA, but
+  // the main thread (which calls in here from our CanPlayType implementation)
+  // is not.
+  bool canCreateDecoder = false;
+  mozilla::mscom::EnsureMTA([&]() -> void {
+    if (FAILED(wmf::MFStartup())) {
+      return;
+    }
     RefPtr<MFTDecoder> decoder(new MFTDecoder());
-    hasdecoder = SUCCEEDED(decoder->Create(aGuid));
-  }
-  wmf::MFShutdown();
-  return hasdecoder;
+    canCreateDecoder = SUCCEEDED(decoder->Create(aGuid));
+    wmf::MFShutdown();
+  });
+  return canCreateDecoder;
 }
 
 template<const GUID& aGuid>
 static bool
 CanCreateWMFDecoder()
 {
   static StaticMutex sMutex;
   StaticMutexAutoLock lock(sMutex);
--- a/dom/media/platforms/wmf/WMFUtils.cpp
+++ b/dom/media/platforms/wmf/WMFUtils.cpp
@@ -10,16 +10,17 @@
 #include "mozilla/CheckedInt.h"
 #include "mozilla/Logging.h"
 #include "mozilla/RefPtr.h"
 #include "nsTArray.h"
 #include "nsThreadUtils.h"
 #include "nsWindowsHelpers.h"
 #include <initguid.h>
 #include <stdint.h>
+#include "mozilla/mscom/EnsureMTA.h"
 
 #ifdef WMF_MUST_DEFINE_AAC_MFT_CLSID
 // Some SDK versions don't define the AAC decoder CLSID.
 // {32D186A7-218F-4C75-8876-DD77273A8999}
 DEFINE_GUID(CLSID_CMSAACDecMFT, 0x32D186A7, 0x218F, 0x4C75, 0x88, 0x76, 0xDD, 0x77, 0x27, 0x3A, 0x89, 0x99);
 #endif
 
 namespace mozilla {
@@ -215,24 +216,30 @@ MFStartup()
     return hr;
   }
 
   const int MF_WIN7_VERSION = (0x0002 << 16 | MF_API_VERSION);
 
   // decltype is unusable for functions having default parameters
   DECL_FUNCTION_PTR(MFStartup, ULONG, DWORD);
   ENSURE_FUNCTION_PTR_(MFStartup, Mfplat.dll)
-  return MFStartupPtr(MF_WIN7_VERSION, MFSTARTUP_FULL);
+
+  hr = E_FAIL;
+  mozilla::mscom::EnsureMTA(
+    [&]() -> void { hr = MFStartupPtr(MF_WIN7_VERSION, MFSTARTUP_FULL); });
+  return hr;
 }
 
 HRESULT
 MFShutdown()
 {
   ENSURE_FUNCTION_PTR(MFShutdown, Mfplat.dll)
-  return (MFShutdownPtr)();
+  HRESULT hr = E_FAIL;
+  mozilla::mscom::EnsureMTA([&]() -> void { hr = (MFShutdownPtr)(); });
+  return hr;
 }
 
 HRESULT
 MFCreateMediaType(IMFMediaType **aOutMFType)
 {
   ENSURE_FUNCTION_PTR(MFCreateMediaType, Mfplat.dll)
   return (MFCreateMediaTypePtr)(aOutMFType);
 }