Bug 1457536 - refactor Transition in IPC to make it more amenable to fuzing; r?froydnj draft
authorAlex Gaynor <agaynor@mozilla.com>
Thu, 03 May 2018 12:08:48 -0400
changeset 791232 7ada2d50e06dc5dd6154118e449aedbc2b39ca4c
parent 791070 ce588e44f41599808a830db23d190d1ca474a781
push id108749
push userbmo:agaynor@mozilla.com
push dateThu, 03 May 2018 20:10:01 +0000
reviewersfroydnj
bugs1457536
milestone61.0a1
Bug 1457536 - refactor Transition in IPC to make it more amenable to fuzing; r?froydnj Instead of crashing the process inside Transition on a bad state transition, propagate an error up the stack (and crash higher up). MozReview-Commit-ID: JJmAeq6xSfe
ipc/glue/ProtocolUtils.cpp
ipc/glue/ProtocolUtils.h
ipc/ipdl/ipdl/lower.py
--- a/ipc/glue/ProtocolUtils.cpp
+++ b/ipc/glue/ProtocolUtils.cpp
@@ -349,57 +349,55 @@ ArrayLengthReadError(const char* aElemen
 }
 
 void
 SentinelReadError(const char* aClassName)
 {
   MOZ_CRASH_UNSAFE_PRINTF("incorrect sentinel when reading %s", aClassName);
 }
 
-void
+bool
 StateTransition(bool aIsDelete, State* aNext)
 {
   switch (*aNext) {
     case State::Null:
       if (aIsDelete) {
         *aNext = State::Dead;
       }
       break;
     case State::Dead:
-      LogicError("__delete__()d actor");
-      break;
+      return false;
     default:
-      LogicError("corrupted actor state");
-      break;
+      return false;
   }
+  return true;
 }
 
-void
+bool
 ReEntrantDeleteStateTransition(bool aIsDelete,
                                bool aIsDeleteReply,
                                ReEntrantDeleteState* aNext)
 {
   switch (*aNext) {
     case ReEntrantDeleteState::Null:
       if (aIsDelete) {
         *aNext = ReEntrantDeleteState::Dying;
       }
       break;
     case ReEntrantDeleteState::Dead:
-      LogicError("__delete__()d actor");
-      break;
+      return false;
     case ReEntrantDeleteState::Dying:
       if (aIsDeleteReply) {
         *aNext = ReEntrantDeleteState::Dead;
       }
       break;
     default:
-      LogicError("corrupted actor state");
-      break;
+      return false;
   }
+  return true;
 }
 
 void
 TableToArray(const nsTHashtable<nsPtrHashKey<void>>& aTable,
              nsTArray<void*>& aArray)
 {
   uint32_t i = 0;
   void** elements = aArray.AppendElements(aTable.Count());
--- a/ipc/glue/ProtocolUtils.h
+++ b/ipc/glue/ProtocolUtils.h
@@ -752,28 +752,28 @@ void AnnotateSystemError();
 
 enum class State
 {
   Dead,
   Null,
   Start = Null
 };
 
-void
+bool
 StateTransition(bool aIsDelete, State* aNext);
 
 enum class ReEntrantDeleteState
 {
   Dead,
   Null,
   Dying,
   Start = Null,
 };
 
-void
+bool
 ReEntrantDeleteStateTransition(bool aIsDelete,
                                bool aIsDeleteReply,
                                ReEntrantDeleteState* aNext);
 
 /**
  * An endpoint represents one end of a partially initialized IPDL channel. To
  * set up a new top-level protocol:
  *
--- a/ipc/ipdl/ipdl/lower.py
+++ b/ipc/ipdl/ipdl/lower.py
@@ -405,16 +405,21 @@ def errfnRecv(msg, errcode=_Result.ValuE
 def errfnSentinel(rvalue=ExprLiteral.FALSE):
     def inner(msg):
         return [ _sentinelReadError(msg), StmtReturn(rvalue) ]
     return inner
 
 def _destroyMethod():
     return ExprVar('ActorDestroy')
 
+def errfnUnreachable(msg):
+    return [
+        _logicError(msg)
+    ]
+
 class _DestroyReason:
     @staticmethod
     def Type():  return Type('ActorDestroyReason')
 
     Deletion = ExprVar('Deletion')
     AncestorDeletion = ExprVar('AncestorDeletion')
     NormalShutdown = ExprVar('NormalShutdown')
     AbnormalShutdown = ExprVar('AbnormalShutdown')
@@ -3908,17 +3913,17 @@ class _GenerateProtocolActorCode(ipdl.as
         ifsendok = StmtIf(ExprLiteral.FALSE)
         ifsendok.addifstmts(destmts)
         ifsendok.addifstmts([ Whitespace.NL,
                               StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, '&=')) ])
 
         method.addstmt(ifsendok)
 
         if self.protocol.decl.type.hasReentrantDelete:
-            method.addstmts(self.transition(md, actor.var(), reply=True))
+            method.addstmts(self.transition(md, actor.var(), reply=True, errorfn=errfnUnreachable))
 
         method.addstmts(
             self.dtorEpilogue(md, actor.var())
             + [ Whitespace.NL, StmtReturn(sendok) ])
 
         return method
 
     def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion):
@@ -4053,17 +4058,17 @@ class _GenerateProtocolActorCode(ipdl.as
         actorhandle = self.handlevar
 
         stmts = self.deserializeMessage(md, self.side, errfnRecv,
                                         errfnSent=errfnSentinel(_Result.ValuError))
 
         idvar, saveIdStmts = self.saveActorId(md)
         case.addstmts(
             stmts
-            + self.transition(md)
+            + self.transition(md, errorfn=errfnRecv)
             + [ StmtDecl(Decl(r.bareType(self.side), r.var().name))
                 for r in md.returns ]
             # alloc the actor, register it under the foreign ID
             + [ StmtExpr(ExprAssn(
                 actorvar,
                 self.callAllocActor(md, retsems='in', side=self.side))) ]
             + self.ctorPrologue(md, errfn=_Result.ValuError,
                                 idexpr=_actorHId(actorhandle))
@@ -4084,17 +4089,17 @@ class _GenerateProtocolActorCode(ipdl.as
         case = StmtBlock()
 
         stmts = self.deserializeMessage(md, self.side, errfnRecv,
                                         errfnSent=errfnSentinel(_Result.ValuError))
 
         idvar, saveIdStmts = self.saveActorId(md)
         case.addstmts(
             stmts
-            + self.transition(md)
+            + self.transition(md, errorfn=errfnRecv)
             + [ StmtDecl(Decl(r.bareType(self.side), r.var().name))
                 for r in md.returns ]
             + self.invokeRecvHandler(md, implicit=0)
             + [ Whitespace.NL ]
             + saveIdStmts
             + self.makeReply(md, errfnRecv, routingId=idvar)
             + [ Whitespace.NL ]
             + self.genVerifyMessage(md.decl.type.verify, md.returns, errfnRecv,
@@ -4115,17 +4120,17 @@ class _GenerateProtocolActorCode(ipdl.as
 
         idvar, saveIdStmts = self.saveActorId(md)
         declstmts = [ StmtDecl(Decl(r.bareType(self.side), r.var().name))
                       for r in md.returns ]
         if md.decl.type.isAsync() and md.returns:
             declstmts = self.makeResolver(md, errfnRecv, routingId=idvar)
         case.addstmts(
             stmts
-            + self.transition(md)
+            + self.transition(md, errorfn=errfnRecv)
             + saveIdStmts
             + declstmts
             + self.invokeRecvHandler(md)
             + [ Whitespace.NL ]
             + self.makeReply(md, errfnRecv, routingId=idvar)
             + self.genVerifyMessage(md.decl.type.verify, md.returns, errfnRecv,
                                     self.replyvar)
             + [ StmtReturn(_Result.Processed) ])
@@ -4443,17 +4448,17 @@ class _GenerateProtocolActorCode(ipdl.as
     def sendAsync(self, md, msgexpr, actor=None):
         sendok = ExprVar('sendok__')
         resolvefn = ExprVar('aResolve')
         rejectfn = ExprVar('aReject')
 
         sendargs = [ msgexpr ]
         stmts = [ Whitespace.NL,
                   self.logMessage(md, msgexpr, 'Sending ', actor),
-                  self.profilerLabel(md) ] + self.transition(md, actor)
+                  self.profilerLabel(md) ] + self.transition(md, actor, errorfn=errfnUnreachable)
         stmts.append(Whitespace.NL)
 
         # Generate the actual call expression.
         send = ExprSelect(self.protocol.callGetChannel(actor), '->', 'Send')
         if md.returns:
             stmts.append(StmtExpr(ExprCall(send, args=[ msgexpr,
                                                         ExprVar('this'),
                                                         ExprMove(resolvefn),
@@ -4468,17 +4473,17 @@ class _GenerateProtocolActorCode(ipdl.as
 
     def sendBlocking(self, md, msgexpr, replyexpr, actor=None):
         sendok = ExprVar('sendok__')
         return (
             sendok,
             ([ Whitespace.NL,
                self.logMessage(md, msgexpr, 'Sending ', actor),
                self.profilerLabel(md) ]
-            + self.transition(md, actor)
+            + self.transition(md, actor, errorfn=errfnUnreachable)
             + [ Whitespace.NL,
                 StmtDecl(Decl(Type.BOOL, sendok.name)),
                 StmtBlock([
                     StmtExpr(ExprCall(ExprVar('AUTO_PROFILER_TRACING'),
                              [ ExprLiteral.String("IPC"),
                                ExprLiteral.String(self.protocol.name + "::" + md.prettyMsgName()) ])),
                     StmtExpr(ExprAssn(sendok,
                                       ExprCall(ExprSelect(self.protocol.callGetChannel(actor),
@@ -4622,17 +4627,17 @@ class _GenerateProtocolActorCode(ipdl.as
             # only save the ID if we're actually going to use it, to
             # avoid unused-variable warnings
             saveIdStmts = [ StmtDecl(Decl(_actorIdType(), idvar.name),
                                      self.protocol.routingId()) ]
         else:
             saveIdStmts = [ ]
         return idvar, saveIdStmts
 
-    def transition(self, md, actor=None, reply=False):
+    def transition(self, md, actor=None, reply=False, errorfn=None):
         msgid = md.pqMsgId() if not reply else md.pqReplyId()
         args = [
             ExprVar('true' if _deleteId().name == msgid else 'false'),
         ]
         if self.protocol.decl.type.hasReentrantDelete:
             function = 'ReEntrantDeleteStateTransition'
             args.append(
                 ExprVar('true' if _deleteReplyId().name == msgid else 'false'),
@@ -4642,19 +4647,19 @@ class _GenerateProtocolActorCode(ipdl.as
 
         if actor is not None:
             stateexpr = _actorState(actor)
         else:
             stateexpr = self.protocol.stateVar()
 
         args.append(ExprAddrOf(stateexpr))
 
-        return [
-            StmtExpr(ExprCall(ExprVar(function), args=args))
-        ]
+        ifstmt = StmtIf(ExprNot(ExprCall(ExprVar(function), args=args)))
+        ifstmt.addifstmts(errorfn('Transition error'))
+        return [ifstmt]
 
     def endRead(self, msgexpr, iterexpr):
         msgtype = ExprCall(ExprSelect(msgexpr, '.', 'type'), [ ])
         return StmtExpr(ExprCall(ExprSelect(msgexpr, '.', 'EndRead'),
                                  args=[ iterexpr, msgtype ]))
 
 class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
     def __init__(self):