Bug 1457728: Support tri-state comparators in nsTArray. r?erahm draft
authorKris Maglione <maglione.k@gmail.com>
Sat, 28 Apr 2018 15:15:04 -0700
changeset 814332 e2384e8f98b76a55aece2358528e9d9cf0272e80
parent 814331 b1e3cd7db269d4cc9c61e754a95169da46213dc2
push id115158
push usermaglione.k@gmail.com
push dateWed, 04 Jul 2018 23:01:20 +0000
reviewerserahm
bugs1457728
milestone63.0a1
Bug 1457728: Support tri-state comparators in nsTArray. r?erahm MozReview-Commit-ID: FC9GPsJJ03K
dom/base/MozQueryInterface.cpp
xpcom/ds/nsTArray.h
xpcom/tests/gtest/TestTArray2.cpp
--- a/dom/base/MozQueryInterface.cpp
+++ b/dom/base/MozQueryInterface.cpp
@@ -23,31 +23,16 @@ static_assert(IID_SIZE == 16,
               "Size of nsID struct changed. Please ensure this code is still valid.");
 
 static int
 CompareIIDs(const nsIID& aA, const nsIID &aB)
 {
   return memcmp((void*)&aA.m0, (void*)&aB.m0, IID_SIZE);
 }
 
-struct IIDComparator
-{
-  bool
-  LessThan(const nsIID& aA, const nsIID &aB) const
-  {
-    return CompareIIDs(aA, aB) < 0;
-  }
-
-  bool
-  Equals(const nsIID& aA, const nsIID &aB) const
-  {
-    return aA.Equals(aB);
-  }
-};
-
 /* static */
 MozQueryInterface*
 ChromeUtils::GenerateQI(const GlobalObject& aGlobal, const Sequence<OwningStringOrIID>& aInterfaces, ErrorResult& aRv)
 {
   JSContext* cx = aGlobal.Context();
   JS::RootedObject xpcIfaces(cx);
 
   nsTArray<nsIID> ifaces;
@@ -93,33 +78,28 @@ ChromeUtils::GenerateQI(const GlobalObje
     nsCOMPtr<nsIJSID> iid = do_QueryInterface(base);
     if (!iid) {
       aRv.Throw(NS_ERROR_INVALID_ARG);
       return nullptr;
     }
     ifaces.AppendElement(*iid->GetID());
   }
 
-  MOZ_ASSERT(!ifaces.Contains(NS_GET_IID(nsISupports), IIDComparator()));
+  MOZ_ASSERT(!ifaces.Contains(NS_GET_IID(nsISupports), CompareIIDs));
   ifaces.AppendElement(NS_GET_IID(nsISupports));
 
-  ifaces.Sort(IIDComparator());
+  ifaces.Sort(CompareIIDs);
 
   return new MozQueryInterface(std::move(ifaces));
 }
 
 bool
 MozQueryInterface::QueriesTo(const nsIID& aIID) const
 {
-  // We use BinarySearchIf here because nsTArray::ContainsSorted requires
-  // twice as many comparisons.
-  size_t result;
-  return BinarySearchIf(mInterfaces, 0, mInterfaces.Length(),
-                        [&] (const nsIID& aOther) { return CompareIIDs(aIID, aOther); },
-                        &result);
+  return mInterfaces.ContainsSorted(aIID, CompareIIDs);
 }
 
 void
 MozQueryInterface::LegacyCall(JSContext* cx, JS::Handle<JS::Value> thisv,
                               nsIJSID* aIID,
                               JS::MutableHandle<JS::Value> aResult,
                               ErrorResult& aRv) const
 {
--- a/xpcom/ds/nsTArray.h
+++ b/xpcom/ds/nsTArray.h
@@ -767,53 +767,103 @@ struct nsTArray_TypedBase<JS::Heap<E>, D
   {
     Derived* self = static_cast<Derived*>(this);
     return *reinterpret_cast<FallibleTArray<E> *>(self);
   }
 };
 
 namespace detail {
 
-template<class Item, class Comparator>
-struct ItemComparatorEq
+// These helpers allow us to differentiate between tri-state comparator
+// functions and classes with LessThan() and Equal() methods. If an object, when
+// called as a function with two instances of our element type, returns an int,
+// we treat it as a tri-state comparator.
+//
+// T is the type of the comparator object we want to check. U is the array
+// element type that we'll be comparing.
+//
+// V is never passed, and is only used to allow us to specialize on the return
+// value of the comparator function.
+template <typename T, typename U, typename V = int>
+struct IsCompareMethod : mozilla::FalseType {};
+
+template <typename T, typename U>
+struct IsCompareMethod<T, U, decltype(mozilla::DeclVal<T>()(mozilla::DeclVal<U>(), mozilla::DeclVal<U>()))>
+  : mozilla::TrueType {};
+
+// These two wrappers allow us to use either a tri-state comparator, or an
+// object with Equals() and LessThan() methods interchangeably. They provide a
+// tri-state Compare() method, and Equals() method, and a LessThan() method.
+//
+// Depending on the type of the underlying comparator, they either pass these
+// through directly, or synthesize them from the methods available on the
+// comparator.
+//
+// Callers should always use the most-specific of these methods that match their
+// purpose.
+
+// Comparator wrapper for a tri-state comparator function
+template <typename T, typename U, bool IsCompare = IsCompareMethod<T, U>::value>
+struct CompareWrapper
 {
-  const Item& mItem;
-  const Comparator& mComp;
-  ItemComparatorEq(const Item& aItem, const Comparator& aComp)
-    : mItem(aItem)
-    , mComp(aComp)
+  MOZ_IMPLICIT CompareWrapper(const T& aComparator)
+    : mComparator(aComparator)
   {}
-  template<class T>
-  int operator()(const T& aElement) const {
-    if (mComp.Equals(aElement, mItem)) {
-      return 0;
-    }
+
+  template <typename A, typename B>
+  int Compare(A& aLeft, B& aRight) const
+  {
+    return mComparator(aLeft, aRight);
+  }
 
-    return mComp.LessThan(aElement, mItem) ? 1 : -1;
+  template <typename A, typename B>
+  bool Equals(A& aLeft, B& aRight) const
+  {
+    return Compare(aLeft, aRight) == 0;
   }
+
+  template <typename A, typename B>
+  bool LessThan(A& aLeft, B& aRight) const
+  {
+    return Compare(aLeft, aRight) < 0;
+  }
+
+  const T& mComparator;
 };
 
-template<class Item, class Comparator>
-struct ItemComparatorFirstElementGT
+// Comparator wrapper for a class with Equals() and LessThan() methods.
+template <typename T, typename U>
+struct CompareWrapper<T, U, false>
 {
-  const Item& mItem;
-  const Comparator& mComp;
-  ItemComparatorFirstElementGT(const Item& aItem, const Comparator& aComp)
-    : mItem(aItem)
-    , mComp(aComp)
+  MOZ_IMPLICIT CompareWrapper(const T& aComparator)
+    : mComparator(aComparator)
   {}
-  template<class T>
-  int operator()(const T& aElement) const {
-    if (mComp.LessThan(aElement, mItem) ||
-        mComp.Equals(aElement, mItem)) {
-      return 1;
-    } else {
-      return -1;
+
+  template <typename A, typename B>
+  int Compare(A& aLeft, B& aRight) const
+  {
+    if (Equals(aLeft, aRight)) {
+      return 0;
     }
+    return LessThan(aLeft, aRight) ? -1 : 1;
   }
+
+  template <typename A, typename B>
+  bool Equals(A& aLeft, B& aRight) const
+  {
+    return mComparator.Equals(aLeft, aRight);
+  }
+
+  template <typename A, typename B>
+  bool LessThan(A& aLeft, B& aRight) const
+  {
+    return mComparator.LessThan(aLeft, aRight);
+  }
+
+  const T& mComparator;
 };
 
 } // namespace detail
 
 //
 // nsTArray_Impl contains most of the guts supporting nsTArray, FallibleTArray,
 // AutoTArray.
 //
@@ -1162,20 +1212,22 @@ public:
   // @param aItem  The item to search for.
   // @param aStart The index to start from.
   // @param aComp  The Comparator used to determine element equality.
   // @return       The index of the found element or NoIndex if not found.
   template<class Item, class Comparator>
   index_type IndexOf(const Item& aItem, index_type aStart,
                      const Comparator& aComp) const
   {
+    ::detail::CompareWrapper<Comparator, Item> comp(aComp);
+
     const elem_type* iter = Elements() + aStart;
     const elem_type* iend = Elements() + Length();
     for (; iter != iend; ++iter) {
-      if (aComp.Equals(*iter, aItem)) {
+      if (comp.Equals(*iter, aItem)) {
         return index_type(iter - Elements());
       }
     }
     return NoIndex;
   }
 
   // This method searches for the offset of the first element in this
   // array that is equal to the given element.  This method assumes
@@ -1195,21 +1247,23 @@ public:
   // @param aStart The index to start from.  If greater than or equal to the
   //               length of the array, then the entire array is searched.
   // @param aComp  The Comparator used to determine element equality.
   // @return       The index of the found element or NoIndex if not found.
   template<class Item, class Comparator>
   index_type LastIndexOf(const Item& aItem, index_type aStart,
                          const Comparator& aComp) const
   {
+    ::detail::CompareWrapper<Comparator, Item> comp(aComp);
+
     size_type endOffset = aStart >= Length() ? Length() : aStart + 1;
     const elem_type* iend = Elements() - 1;
     const elem_type* iter = iend + endOffset;
     for (; iter != iend; --iter) {
-      if (aComp.Equals(*iter, aItem)) {
+      if (comp.Equals(*iter, aItem)) {
         return index_type(iter - Elements());
       }
     }
     return NoIndex;
   }
 
   // This method searches for the offset of the last element in this
   // array that is equal to the given element.  This method assumes
@@ -1231,20 +1285,30 @@ public:
   // on which one will be returned.
   // @param aItem  The item to search for.
   // @param aComp  The Comparator used.
   // @return       The index of the found element or NoIndex if not found.
   template<class Item, class Comparator>
   index_type BinaryIndexOf(const Item& aItem, const Comparator& aComp) const
   {
     using mozilla::BinarySearchIf;
-    typedef ::detail::ItemComparatorEq<Item, Comparator> Cmp;
+    ::detail::CompareWrapper<Comparator, Item> comp(aComp);
 
     size_t index;
-    bool found = BinarySearchIf(*this, 0, Length(), Cmp(aItem, aComp), &index);
+    bool found = BinarySearchIf(
+      *this, 0, Length(),
+      // Note: We pass the Compare() args here in reverse order and negate the
+      // results for compatibility reasons. Some existing callers use Equals()
+      // functions with first arguments which match aElement but not aItem, or
+      // second arguments that match aItem but not aElement. To accommodate
+      // those callers, we preserve the argument order of the older version of
+      // this API. These callers, however, should be fixed, and this special
+      // case removed.
+      [&] (const elem_type& aElement) { return -comp.Compare(aElement, aItem); },
+      &index);
     return found ? index : NoIndex;
   }
 
   // This method searches for the offset for the element in this array
   // that is equal to the given element. The array is assumed to be sorted.
   // This method assumes that 'operator==' and 'operator<' are defined.
   // @param aItem  The item to search for.
   // @return       The index of the found element or NoIndex if not found.
@@ -1529,20 +1593,22 @@ public:
   // @param aComp  The Comparator used.
   // @return        The index of greatest element <= to |aItem|
   // @precondition The array is sorted
   template<class Item, class Comparator>
   index_type IndexOfFirstElementGt(const Item& aItem,
                                    const Comparator& aComp) const
   {
     using mozilla::BinarySearchIf;
-    typedef ::detail::ItemComparatorFirstElementGT<Item, Comparator> Cmp;
+    ::detail::CompareWrapper<Comparator, Item> comp(aComp);
 
     size_t index;
-    BinarySearchIf(*this, 0, Length(), Cmp(aItem, aComp), &index);
+    BinarySearchIf(*this, 0, Length(),
+                   [&] (const elem_type& aElement) { return comp.Compare(aElement, aItem) <= 0 ? 1 : -1; },
+                   &index);
     return index;
   }
 
   // A variation on the IndexOfFirstElementGt method defined above.
   template<class Item>
   index_type
   IndexOfFirstElementGt(const Item& aItem) const
   {
@@ -2028,27 +2094,29 @@ public:
   // maps the callback API expected by NS_QuickSort to the Comparator API
   // used by nsTArray_Impl.  See nsTArray_Impl::Sort.
   template<class Comparator>
   static int Compare(const void* aE1, const void* aE2, void* aData)
   {
     const Comparator* c = reinterpret_cast<const Comparator*>(aData);
     const elem_type* a = static_cast<const elem_type*>(aE1);
     const elem_type* b = static_cast<const elem_type*>(aE2);
-    return c->LessThan(*a, *b) ? -1 : (c->Equals(*a, *b) ? 0 : 1);
+    return c->Compare(*a, *b);
   }
 
   // This method sorts the elements of the array.  It uses the LessThan
   // method defined on the given Comparator object to collate elements.
   // @param aComp The Comparator used to collate elements.
   template<class Comparator>
   void Sort(const Comparator& aComp)
   {
+    ::detail::CompareWrapper<Comparator, elem_type> comp(aComp);
+
     NS_QuickSort(Elements(), Length(), sizeof(elem_type),
-                 Compare<Comparator>, const_cast<Comparator*>(&aComp));
+                 Compare<decltype(comp)>, &comp);
   }
 
   // A variation on the Sort method defined above that assumes that
   // 'operator<' is defined for elem_type.
   void Sort() { Sort(nsDefaultComparator<elem_type, elem_type>()); }
 
   // This method reverses the array in place.
   void Reverse()
--- a/xpcom/tests/gtest/TestTArray2.cpp
+++ b/xpcom/tests/gtest/TestTArray2.cpp
@@ -1034,9 +1034,56 @@ TEST(TArray, test_SetLengthAndRetainStor
   }
 
 
 #undef FOR_EACH
 #undef LPAREN
 #undef RPAREN
 }
 
+template <typename Comparator>
+bool
+TestCompareMethods(const Comparator& aComp)
+{
+  nsTArray<int> ary({57, 4, 16, 17, 3, 5, 96, 12});
+
+  ary.Sort(aComp);
+
+  const int sorted[] = {3, 4, 5, 12, 16, 17, 57, 96 };
+  for (size_t i = 0; i < MOZ_ARRAY_LENGTH(sorted); i++) {
+    if (sorted[i] != ary[i]) {
+      return false;
+    }
+  }
+
+  if (!ary.ContainsSorted(5, aComp)) {
+    return false;
+  }
+  if (ary.ContainsSorted(42, aComp)) {
+    return false;
+  }
+
+  if (ary.BinaryIndexOf(16, aComp) != 4) {
+    return false;
+  }
+
+  return true;
+}
+
+struct IntComparator
+{
+  bool Equals(int aLeft, int aRight) const
+  {
+    return aLeft == aRight;
+  }
+
+  bool LessThan(int aLeft, int aRight) const
+  {
+    return aLeft < aRight;
+  }
+};
+
+TEST(TArray, test_comparator_objects) {
+  ASSERT_TRUE(TestCompareMethods(IntComparator()));
+  ASSERT_TRUE(TestCompareMethods([] (int aLeft, int aRight) { return aLeft - aRight; }));
+}
+
 } // namespace TestTArray