Bug 1230311 - clang-plugin - static analysis to enforce that if a method is overridden also base method is called. r?mystor draft
authorAndi-Bogdan Postelnicu <bpostelnicu@mozilla.com>
Thu, 06 Oct 2016 11:00:29 +0300
changeset 421475 88a6c6caea53df850160a5f07270f8bfefb7828f
parent 421441 c7d62e6d052c5d2638b08d480a720254ea09ff2d
child 533100 40ed41e7f6c51842ed59a962fc82dbffda331fb3
push id31523
push userbmo:bpostelnicu@mozilla.com
push dateThu, 06 Oct 2016 08:10:17 +0000
reviewersmystor
bugs1230311
milestone52.0a1
Bug 1230311 - clang-plugin - static analysis to enforce that if a method is overridden also base method is called. r?mystor MozReview-Commit-ID: 1JYzYZZOh3W
build/clang-plugin/clang-plugin.cpp
build/clang-plugin/tests/TestOverrideBaseCall.cpp
build/clang-plugin/tests/moz.build
--- a/build/clang-plugin/clang-plugin.cpp
+++ b/build/clang-plugin/clang-plugin.cpp
@@ -171,16 +171,29 @@ private:
     virtual void run(const MatchFinder::MatchResult &Result);
   };
 
   class SprintfLiteralChecker : public MatchFinder::MatchCallback {
   public:
     virtual void run(const MatchFinder::MatchResult &Result);
   };
 
+  class OverrideBaseCallChecker : public MatchFinder::MatchCallback {
+  public:
+    virtual void run(const MatchFinder::MatchResult &Result);
+  private:
+    void evaluateExpression(const Stmt *StmtExpr,
+        std::list<const CXXMethodDecl*> &MethodList);
+    void getRequiredBaseMethod(const CXXMethodDecl* Method,
+        std::list<const CXXMethodDecl*>& MethodsList);
+    void findBaseMethodCall(const CXXMethodDecl* Method,
+        std::list<const CXXMethodDecl*>& MethodsList);
+    bool isRequiredBaseMethod(const CXXMethodDecl *Method);
+  };
+
   ScopeChecker Scope;
   ArithmeticArgChecker ArithmeticArg;
   TrivialCtorDtorChecker TrivialCtorDtor;
   NaNExprChecker NaNExpr;
   NoAddRefReleaseOnReturnChecker NoAddRefReleaseOnReturn;
   RefCountedInsideLambdaChecker RefCountedInsideLambda;
   ExplicitOperatorBoolChecker ExplicitOperatorBool;
   NoDuplicateRefCntMemberChecker NoDuplicateRefCntMember;
@@ -189,16 +202,17 @@ private:
   NonMemMovableMemberChecker NonMemMovableMember;
   ExplicitImplicitChecker ExplicitImplicit;
   NoAutoTypeChecker NoAutoType;
   NoExplicitMoveConstructorChecker NoExplicitMoveConstructor;
   RefCountedCopyConstructorChecker RefCountedCopyConstructor;
   AssertAssignmentChecker AssertAttribution;
   KungFuDeathGripChecker KungFuDeathGrip;
   SprintfLiteralChecker SprintfLiteral;
+  OverrideBaseCallChecker OverrideBaseCall;
   MatchFinder AstMatcher;
 };
 
 namespace {
 
 std::string getDeclarationNamespace(const Decl *Declaration) {
   const DeclContext *DC =
       Declaration->getDeclContext()->getEnclosingNamespaceContext();
@@ -923,16 +937,23 @@ AST_MATCHER(CallExpr, isSnprintfLikeFunc
 
 AST_MATCHER(CXXRecordDecl, isLambdaDecl) {
   return Node.isLambda();
 }
 
 AST_MATCHER(QualType, isRefPtr) {
   return typeIsRefPtr(Node);
 }
+
+AST_MATCHER(CXXRecordDecl, hasBaseClasses) {
+  const CXXRecordDecl *Decl = Node.getCanonicalDecl();
+
+  // Must have definition and should inherit other classes
+  return Decl && Decl->hasDefinition() && Decl->getNumBases();
+}
 }
 }
 
 namespace {
 
 void CustomTypeAnnotation::dumpAnnotationReason(DiagnosticsEngine &Diag,
                                                 QualType T,
                                                 SourceLocation Loc) {
@@ -1261,16 +1282,19 @@ DiagnosticsMatcher::DiagnosticsMatcher()
         allOf(hasArgument(0, ignoringParenImpCasts(declRefExpr().bind("buffer"))),
                              anyOf(hasArgument(1, sizeOfExpr(hasIgnoringParenImpCasts(declRefExpr().bind("size")))),
                                    hasArgument(1, integerLiteral().bind("immediate")),
                                    hasArgument(1, declRefExpr(to(varDecl(hasType(isConstQualified()),
                                                                          hasInitializer(integerLiteral().bind("constant")))))))))
         .bind("funcCall"),
       &SprintfLiteral
   );
+
+  AstMatcher.addMatcher(cxxRecordDecl(hasBaseClasses()).bind("class"),
+      &OverrideBaseCall);
 }
 
 // These enum variants determine whether an allocation has occured in the code.
 enum AllocationVariety {
   AV_None,
   AV_Global,
   AV_Automatic,
   AV_Temporary,
@@ -1955,16 +1979,112 @@ void DiagnosticsMatcher::SprintfLiteralC
 
     if (Type->getSize().ule(Literal->getValue())) {
       Diag.Report(D->getLocStart(), ErrorID) << Name << Replacement;
       Diag.Report(D->getLocStart(), NoteID) << Name;
     }
   }
 }
 
+bool DiagnosticsMatcher::OverrideBaseCallChecker::isRequiredBaseMethod(
+    const CXXMethodDecl *Method) {
+  return MozChecker::hasCustomAnnotation(Method, "moz_required_base_method");
+}
+
+void DiagnosticsMatcher::OverrideBaseCallChecker::evaluateExpression(
+    const Stmt *StmtExpr, std::list<const CXXMethodDecl*> &MethodList) {
+  // Continue while we have methods in our list
+  if (!MethodList.size()) {
+    return;
+  }
+
+  if (auto MemberFuncCall = dyn_cast<CXXMemberCallExpr>(StmtExpr)) {
+    if (auto Method = dyn_cast<CXXMethodDecl>(
+        MemberFuncCall->getDirectCallee())) {
+      findBaseMethodCall(Method, MethodList);
+    }
+  }
+
+  for (auto S : StmtExpr->children()) {
+    if (S) {
+      evaluateExpression(S, MethodList);
+    }
+  }
+}
+
+void DiagnosticsMatcher::OverrideBaseCallChecker::getRequiredBaseMethod(
+    const CXXMethodDecl *Method,
+    std::list<const CXXMethodDecl*>& MethodsList) {
+
+  if (isRequiredBaseMethod(Method)) {
+    MethodsList.push_back(Method);
+  } else {
+    // Loop through all it's base methods.
+    for (auto BaseMethod : Method->overridden_methods()) {
+      getRequiredBaseMethod(BaseMethod, MethodsList);
+    }
+  }
+}
+
+void DiagnosticsMatcher::OverrideBaseCallChecker::findBaseMethodCall(
+    const CXXMethodDecl* Method,
+    std::list<const CXXMethodDecl*>& MethodsList) {
+
+  MethodsList.remove(Method);
+  // Loop also through all it's base methods;
+  for (auto baseMethod : Method->overridden_methods()) {
+    findBaseMethodCall(baseMethod, MethodsList);
+  }
+}
+
+void DiagnosticsMatcher::OverrideBaseCallChecker::run(
+    const MatchFinder::MatchResult &Result) {
+  DiagnosticsEngine &Diag = Result.Context->getDiagnostics();
+  unsigned OverrideBaseCallCheckID = Diag.getDiagnosticIDs()->getCustomDiagID(
+      DiagnosticIDs::Error,
+      "Method %0 must be called in all overrides, but is not called in "
+      "this override defined for class %1");
+  const CXXRecordDecl *Decl = Result.Nodes.getNodeAs<CXXRecordDecl>("class");
+
+  // Loop through the methods and look for the ones that are overridden.
+  for (auto Method : Decl->methods()) {
+    // If this method doesn't override other methods or it doesn't have a body,
+    // continue to the next declaration.
+    if (!Method->size_overridden_methods() || !Method->hasBody()) {
+      continue;
+    }
+
+    // Preferred the usage of list instead of vector in order to avoid
+    // calling erase-remove when deleting items
+    std::list<const CXXMethodDecl*> MethodsList;
+    // For each overridden method push it to a list if it meets our
+    // criteria
+    for (auto BaseMethod : Method->overridden_methods()) {
+      getRequiredBaseMethod(BaseMethod, MethodsList);
+    }
+
+    // If no method has been found then no annotation was used
+    // so checking is not needed
+    if (!MethodsList.size()) {
+      continue;
+    }
+
+    // Loop through the body of our method and search for calls to
+    // base methods
+    evaluateExpression(Method->getBody(), MethodsList);
+
+    // If list is not empty pop up errors
+    for (auto BaseMethod : MethodsList) {
+      Diag.Report(Method->getLocation(), OverrideBaseCallCheckID)
+          << BaseMethod->getQualifiedNameAsString()
+          << Decl->getName();
+    }
+  }
+}
+
 class MozCheckAction : public PluginASTAction {
 public:
   ASTConsumerPtr CreateASTConsumer(CompilerInstance &CI,
                                    StringRef FileName) override {
 #if CLANG_VERSION_FULL >= 306
     std::unique_ptr<MozChecker> Checker(llvm::make_unique<MozChecker>(CI));
     ASTConsumerPtr Other(Checker->getOtherConsumer());
 
new file mode 100644
--- /dev/null
+++ b/build/clang-plugin/tests/TestOverrideBaseCall.cpp
@@ -0,0 +1,175 @@
+#define MOZ_REQUIRED_BASE_METHOD __attribute__((annotate("moz_required_base_method")))
+
+class Base {
+public:
+  virtual void fo() MOZ_REQUIRED_BASE_METHOD {
+  }
+
+  virtual int foRet() MOZ_REQUIRED_BASE_METHOD {
+    return 0;
+  }
+};
+
+class BaseOne : public Base {
+public:
+  virtual void fo() MOZ_REQUIRED_BASE_METHOD {
+    Base::fo();
+  }
+};
+
+class BaseSecond : public Base {
+public:
+  virtual void fo() MOZ_REQUIRED_BASE_METHOD {
+   Base::fo();
+  }
+};
+
+class Deriv : public BaseOne, public BaseSecond {
+public:
+  void func() {
+  }
+
+  void fo() {
+    func();
+    BaseSecond::fo();
+    BaseOne::fo();
+  }
+};
+
+class DerivSimple : public Base {
+public:
+  void fo() { // expected-error {{Method Base::fo must be called in all overrides, but is not called in this override defined for class DerivSimple}}
+  }
+};
+
+class BaseVirtualOne : public virtual Base {
+};
+
+class BaseVirtualSecond: public virtual Base {
+};
+
+class DerivVirtual : public BaseVirtualOne, public BaseVirtualSecond {
+public:
+  void fo() {
+    Base::fo();
+  }
+};
+
+class DerivIf : public Base {
+public:
+  void fo() {
+    if (true) {
+      Base::fo();
+    }
+  }
+};
+
+class DerivIfElse : public Base {
+public:
+  void fo() {
+    if (true) {
+      Base::fo();
+    } else {
+      Base::fo();
+    }
+  }
+};
+
+class DerivFor : public Base {
+public:
+  void fo() {
+    for (int i = 0; i < 10; i++) {
+      Base::fo();
+    }
+  }
+};
+
+class DerivDoWhile : public Base {
+public:
+  void fo() {
+    do {
+      Base::fo();
+    } while(false);
+  }
+};
+
+class DerivWhile : public Base {
+public:
+  void fo() {
+    while (true) {
+      Base::fo();
+      break;
+    }
+  }
+};
+
+class DerivAssignment : public Base {
+public:
+  int foRet() {
+    return foRet();
+  }
+};
+
+class BaseOperator {
+private:
+  int value;
+public:
+  BaseOperator() : value(0) {
+  }
+  virtual BaseOperator& operator++() MOZ_REQUIRED_BASE_METHOD {
+    value++;
+    return *this;
+  }
+};
+
+class DerivOperatorErr : public BaseOperator {
+private:
+  int value;
+public:
+  DerivOperatorErr() : value(0) {
+  }
+  DerivOperatorErr& operator++() { // expected-error {{Method BaseOperator::operator++ must be called in all overrides, but is not called in this override defined for class DerivOperatorErr}}
+    value++;
+    return *this;
+  }
+};
+
+class DerivOperator : public BaseOperator {
+private:
+  int value;
+public:
+  DerivOperator() : value(0) {
+  }
+  DerivOperator& operator++() {
+    BaseOperator::operator++();
+    value++;
+    return *this;
+  }
+};
+
+class DerivPrime : public Base {
+public:
+  void fo() {
+    Base::fo();
+  }
+};
+
+class DerivSecondErr : public DerivPrime {
+public:
+  void fo() { // expected-error {{Method Base::fo must be called in all overrides, but is not called in this override defined for class DerivSecondErr}}
+  }
+};
+
+class DerivSecond : public DerivPrime {
+public:
+  void fo() {
+    Base::fo();
+  }
+};
+
+class DerivSecondIndirect : public DerivPrime {
+public:
+  void fo() {
+    DerivPrime::fo();
+  }
+};
--- a/build/clang-plugin/tests/moz.build
+++ b/build/clang-plugin/tests/moz.build
@@ -27,16 +27,17 @@ SOURCES += [
     'TestNoAutoType.cpp',
     'TestNoDuplicateRefCntMember.cpp',
     'TestNoExplicitMoveConstructor.cpp',
     'TestNonHeapClass.cpp',
     'TestNonMemMovable.cpp',
     'TestNonMemMovableStd.cpp',
     'TestNonTemporaryClass.cpp',
     'TestNoRefcountedInsideLambdas.cpp',
+    'TestOverrideBaseCall.cpp',
     'TestRefCountedCopyConstructor.cpp',
     'TestSprintfLiteral.cpp',
     'TestStackClass.cpp',
     'TestTrivialCtorDtor.cpp',
 ]
 
 DISABLE_STL_WRAPPING = True
 NO_VISIBILITY_FLAGS = True