Bug 1345897 - Use a separate error function for sentinel errors. r=kanru draft
authorAndrew McCreight <continuation@gmail.com>
Thu, 09 Mar 2017 13:37:55 -0800
changeset 496889 dff3382f4935d645188a01660c3769ff6e6c64da
parent 496888 bbf1ba5db00b2b32af4b884b681c50619f1b58ab
child 548739 26e0acda34dadb36aa4e4feee08c239c382b489a
push id48732
push userbmo:continuation@gmail.com
push dateFri, 10 Mar 2017 21:13:47 +0000
reviewerskanru
bugs1345897
milestone55.0a1
Bug 1345897 - Use a separate error function for sentinel errors. r=kanru Using a separate error function will distinguish mismatched sentinels from other errors, such as array length problems. MozReview-Commit-ID: Gl8swNhqLns
ipc/glue/ProtocolUtils.cpp
ipc/glue/ProtocolUtils.h
ipc/ipdl/ipdl/lower.py
--- a/ipc/glue/ProtocolUtils.cpp
+++ b/ipc/glue/ProtocolUtils.cpp
@@ -342,23 +342,31 @@ MismatchedActorTypeError(const char* aAc
 
 void
 UnionTypeReadError(const char* aUnionName)
 {
   nsPrintfCString message("error deserializing type of union %s", aUnionName);
   NS_RUNTIMEABORT(message.get());
 }
 
-void ArrayLengthReadError(const char* aElementName)
+void
+ArrayLengthReadError(const char* aElementName)
 {
   nsPrintfCString message("error deserializing length of %s[]", aElementName);
   NS_RUNTIMEABORT(message.get());
 }
 
 void
+SentinelReadError(const char* aClassName)
+{
+  nsPrintfCString message("incorrect sentinel when reading %s", aClassName);
+  NS_RUNTIMEABORT(message.get());
+}
+
+void
 TableToArray(const nsTHashtable<nsPtrHashKey<void>>& aTable,
              nsTArray<void*>& aArray)
 {
   uint32_t i = 0;
   void** elements = aArray.AppendElements(aTable.Count());
   for (auto iter = aTable.ConstIter(); !iter.Done(); iter.Next()) {
     elements[i] = iter.Get()->GetKey();
     ++i;
--- a/ipc/glue/ProtocolUtils.h
+++ b/ipc/glue/ProtocolUtils.h
@@ -484,16 +484,19 @@ MOZ_NEVER_INLINE void
 MismatchedActorTypeError(const char* aActorDescription);
 
 MOZ_NEVER_INLINE void
 UnionTypeReadError(const char* aUnionName);
 
 MOZ_NEVER_INLINE void
 ArrayLengthReadError(const char* aElementName);
 
+MOZ_NEVER_INLINE void
+SentinelReadError(const char* aElementName);
+
 struct PrivateIPDLInterface {};
 
 nsresult
 Bridge(const PrivateIPDLInterface&,
        MessageChannel*, base::ProcessId, MessageChannel*, base::ProcessId,
        ProtocolId, ProtocolId);
 
 bool
--- a/ipc/ipdl/ipdl/lower.py
+++ b/ipc/ipdl/ipdl/lower.py
@@ -401,16 +401,21 @@ def _arrayLengthReadError(elementname):
         ExprCall(ExprVar('mozilla::ipc::ArrayLengthReadError'),
                  args=[ ExprLiteral.String(elementname) ]))
 
 def _unionTypeReadError(unionname):
     return StmtExpr(
         ExprCall(ExprVar('mozilla::ipc::UnionTypeReadError'),
                  args=[ ExprLiteral.String(unionname) ]))
 
+def _sentinelReadError(classname):
+    return StmtExpr(
+        ExprCall(ExprVar('mozilla::ipc::SentinelReadError'),
+                 args=[ ExprLiteral.String(classname) ]))
+
 def _killProcess(pid):
     return ExprCall(
         ExprVar('base::KillProcess'),
         args=[ pid,
                # XXX this is meaningless on POSIX
                ExprVar('base::PROCESS_END_KILLED_BY_USER'),
                ExprLiteral.FALSE ])
 
@@ -462,16 +467,21 @@ def errfnRead(msg):
     return [ _fatalError(msg), StmtReturn.FALSE ]
 
 def errfnArrayLength(elementname):
     return [ _arrayLengthReadError(elementname), StmtReturn.FALSE ]
 
 def errfnUnionType(unionname):
     return [ _unionTypeReadError(unionname), StmtReturn.FALSE ]
 
+def errfnSentinel(rvalue=ExprLiteral.FALSE):
+    def inner(msg):
+        return [ _sentinelReadError(msg), StmtReturn(rvalue) ]
+    return inner
+
 def _destroyMethod():
     return ExprVar('ActorDestroy')
 
 class _DestroyReason:
     @staticmethod
     def Type():  return Type('ActorDestroyReason')
 
     Deletion = ExprVar('Deletion')
@@ -3374,27 +3384,29 @@ class _GenerateProtocolActorCode(ipdl.as
         forread = StmtFor(init=ExprAssn(Decl(Type.UINT32, ivar.name),
                                         ExprLiteral.ZERO),
                           cond=ExprBinary(ivar, '<', lenvar),
                           update=ExprPrefixUnop(ivar, '++'))
         forread.addstmt(
             self.checkedRead(eltipdltype, ExprAddrOf(ExprIndex(elemsvar, ivar)),
                              msgvar, itervar, errfnRead,
                              '\'' + eltipdltype.name() + '[i]\'',
-                             sentinelKey=arraytype.name()))
+                             sentinelKey=arraytype.name(),
+                             errfnSentinel=errfnSentinel()))
         appendstmt = StmtDecl(Decl(directtype, elemsvar.name),
                               init=ExprCall(ExprSelect(favar, '.', 'AppendElements'),
                                             args=[ lenvar ]))
         read.addstmts([
             StmtDecl(Decl(_cxxArrayType(_cxxBareType(arraytype.basetype, self.side)), favar.name)),
             StmtDecl(Decl(Type.UINT32, lenvar.name)),
             self.checkedRead(None, ExprAddrOf(lenvar),
                              msgvar, itervar, errfnArrayLength,
                              [ arraytype.name() ],
-                             sentinelKey=('length', arraytype.name())),
+                             sentinelKey=('length', arraytype.name()),
+                             errfnSentinel=errfnSentinel()),
             Whitespace.NL,
             appendstmt,
             forread,
             StmtExpr(_callCxxSwapArrayElements(var, favar, '->')),
             StmtReturn.TRUE
         ])
 
         self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ])
@@ -3519,17 +3531,18 @@ class _GenerateProtocolActorCode(ipdl.as
             return ExprCall(f.getMethod(thisexpr=var, sel=sel))
 
         for f in sd.fields:
             desc = '\'' + f.getMethod().name + '\' (' + f.ipdltype.name() + \
                    ') member of \'' + intype.name + '\''
             writefield = self.checkedWrite(f.ipdltype, get('.', f), msgvar, sentinelKey=f.basename)
             readfield = self.checkedRead(f.ipdltype,
                                          ExprAddrOf(get('->', f)),
-                                         msgvar, itervar, errfnRead, desc, sentinelKey=f.basename)
+                                         msgvar, itervar, errfnRead, desc,
+                                         sentinelKey=f.basename, errfnSentinel=errfnSentinel())
             if f.special and f.side != self.side:
                 writefield = Whitespace(
                     "// skipping actor field that's meaningless on this side\n", indent=1)
                 readfield = Whitespace(
                     "// skipping actor field that's meaningless on this side\n", indent=1)
             write.addstmt(writefield)
             read.addstmt(readfield)
 
@@ -3583,17 +3596,18 @@ class _GenerateProtocolActorCode(ipdl.as
                 ct = c.bareType()
                 readcase.addstmts([
                     StmtDecl(Decl(ct, tmpvar.name), init=c.defaultValue()),
                     StmtExpr(ExprAssn(ExprDeref(var), tmpvar)),
                     self.checkedRead(
                         c.ipdltype,
                         ExprAddrOf(ExprCall(ExprSelect(var, '->',
                                                        c.getTypeName()))),
-                        msgvar, itervar, errfnRead, 'Union type', sentinelKey=origenum),
+                        msgvar, itervar, errfnRead, 'Union type',
+                        sentinelKey=origenum, errfnSentinel=errfnSentinel()),
                     StmtReturn(ExprLiteral.TRUE)
                 ])
 
             readswitch.addcase(caselabel, readcase)
 
         unknowntype = 'unknown union type'
         writeswitch.addcase(DefaultLabel(),
                             StmtBlock([ _fatalError(unknowntype),
@@ -3612,17 +3626,17 @@ class _GenerateProtocolActorCode(ipdl.as
 
         read = MethodDefn(self.readMethodDecl(outtype, var))
         read.addstmts([
             uniontdef,
             StmtDecl(Decl(Type.INT, typevar.name)),
             self.checkedRead(
                 None, ExprAddrOf(typevar), msgvar, itervar, errfnUnionType,
                 [ uniontype.name() ],
-                sentinelKey=uniontype.name()),
+                sentinelKey=uniontype.name(), errfnSentinel=errfnSentinel()),
             Whitespace.NL,
             readswitch,
         ])
 
         self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ])
 
 
     def writeMethodDecl(self, intype, var, template=None):
@@ -3779,17 +3793,18 @@ class _GenerateProtocolActorCode(ipdl.as
             + self.genVerifyMessage(md.decl.type.verify, md.params,
                                     errfnSendCtor, ExprVar('msg__'))
             + sendstmts
             + self.failCtorIf(md, ExprNot(sendok)))
 
         def errfnCleanupCtor(msg):
             return self.failCtorIf(md, ExprLiteral.TRUE)
         stmts = self.deserializeReply(
-            md, ExprAddrOf(replyvar), self.side, errfnCleanupCtor)
+            md, ExprAddrOf(replyvar), self.side,
+            errfnCleanupCtor, errfnSentinel(ExprLiteral.NULL))
         method.addstmts(stmts + [ StmtReturn(actor.var()) ])
 
         return method
 
 
     def ctorPrologue(self, md, errfn=ExprLiteral.NULL, idexpr=None):
         actordecl = md.actorDecl()
         actorvar = actordecl.var()
@@ -3886,17 +3901,18 @@ class _GenerateProtocolActorCode(ipdl.as
             stmts
             + self.genVerifyMessage(md.decl.type.verify, md.params,
                                     errfnSendDtor, ExprVar('msg__'))
             + [ Whitespace.NL,
                 StmtDecl(Decl(Type('Message'), replyvar.name)) ]
             + sendstmts)
 
         destmts = self.deserializeReply(
-            md, ExprAddrOf(replyvar), self.side, errfnSend, actorvar)
+            md, ExprAddrOf(replyvar), self.side, errfnSend,
+            errfnSentinel(), actorvar)
         ifsendok = StmtIf(ExprLiteral.FALSE)
         ifsendok.addifstmts(destmts)
         ifsendok.addifstmts([ Whitespace.NL,
                               StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, '&=')) ])
 
         method.addstmt(ifsendok)
 
         if self.protocol.decl.type.hasReentrantDelete:
@@ -3949,17 +3965,17 @@ class _GenerateProtocolActorCode(ipdl.as
         msgvar, serstmts = self.makeMessage(md, errfnSend, fromActor)
         replyvar = self.replyvar
 
         sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
         failif = StmtIf(ExprNot(sendok))
         failif.addifstmt(StmtReturn.FALSE)
 
         desstmts = self.deserializeReply(
-            md, ExprAddrOf(replyvar), self.side, errfnSend)
+            md, ExprAddrOf(replyvar), self.side, errfnSend, errfnSentinel())
 
         method.addstmts(
             serstmts
             + self.genVerifyMessage(md.decl.type.verify, md.params, errfnSend,
                                     ExprVar('msg__'))
             + [ Whitespace.NL,
                 StmtDecl(Decl(Type('Message'), replyvar.name)) ]
             + sendstmts
@@ -3972,17 +3988,18 @@ class _GenerateProtocolActorCode(ipdl.as
 
 
     def genCtorRecvCase(self, md):
         lbl = CaseLabel(md.pqMsgId())
         case = StmtBlock()
         actorvar = md.actorDecl().var()
         actorhandle = self.handlevar
 
-        stmts = self.deserializeMessage(md, self.side, errfnRecv)
+        stmts = self.deserializeMessage(md, self.side, errfnRecv,
+                                        errfnSent=errfnSentinel(_Result.ValuError))
 
         idvar, saveIdStmts = self.saveActorId(md)
         case.addstmts(
             stmts
             + self.transition(md)
             + [ StmtDecl(Decl(r.bareType(self.side), r.var().name))
                 for r in md.returns ]
             # alloc the actor, register it under the foreign ID
@@ -4002,17 +4019,18 @@ class _GenerateProtocolActorCode(ipdl.as
 
         return lbl, case
 
 
     def genDtorRecvCase(self, md):
         lbl = CaseLabel(md.pqMsgId())
         case = StmtBlock()
 
-        stmts = self.deserializeMessage(md, self.side, errfnRecv)
+        stmts = self.deserializeMessage(md, self.side, errfnRecv,
+                                        errfnSent=errfnSentinel(_Result.ValuError))
 
         idvar, saveIdStmts = self.saveActorId(md)
         case.addstmts(
             stmts
             + self.transition(md)
             + [ StmtDecl(Decl(r.bareType(self.side), r.var().name))
                 for r in md.returns ]
             + self.invokeRecvHandler(md, implicit=0)
@@ -4028,17 +4046,18 @@ class _GenerateProtocolActorCode(ipdl.as
 
         return lbl, case
 
 
     def genRecvCase(self, md):
         lbl = CaseLabel(md.pqMsgId())
         case = StmtBlock()
 
-        stmts = self.deserializeMessage(md, self.side, errfn=errfnRecv)
+        stmts = self.deserializeMessage(md, self.side, errfn=errfnRecv,
+                                        errfnSent=errfnSentinel(_Result.ValuError))
 
         idvar, saveIdStmts = self.saveActorId(md)
         case.addstmts(
             stmts
             + self.transition(md)
             + [ StmtDecl(Decl(r.bareType(self.side), r.var().name))
                 for r in md.returns ]
             + saveIdStmts
@@ -4155,17 +4174,17 @@ class _GenerateProtocolActorCode(ipdl.as
 
         if reply:
             stmts.append(StmtExpr(ExprCall(
                 ExprSelect(var, '->', 'set_reply'))))
 
         return stmts + [ Whitespace.NL ]
 
 
-    def deserializeMessage(self, md, side, errfn):
+    def deserializeMessage(self, md, side, errfn, errfnSent):
         msgvar = self.msgvar
         itervar = self.itervar
         msgexpr = ExprAddrOf(msgvar)
         isctor = md.decl.type.isCtor()
         stmts = ([
             self.logMessage(md, msgexpr, 'Received ',
                             receiving=True),
             self.profilerLabel(md),
@@ -4180,54 +4199,54 @@ class _GenerateProtocolActorCode(ipdl.as
             # return the raw actor handle so that its ID can be used
             # to construct the "real" actor
             handlevar = self.handlevar
             handletype = Type('ActorHandle')
             decls = [ StmtDecl(Decl(handletype, handlevar.name)) ]
             reads = [ self.checkedRead(None, ExprAddrOf(handlevar), msgexpr,
                                        ExprAddrOf(self.itervar),
                                        errfn, "'%s'" % handletype.name,
-                                       sentinelKey='actor') ]
+                                       sentinelKey='actor', errfnSentinel=errfnSent) ]
             start = 1
 
         stmts.extend((
             [ StmtDecl(Decl(_iterType(ptr=0), self.itervar.name),
                      init=ExprCall(ExprVar('PickleIterator'),
                                    args=[ msgvar ])) ]
             + decls + [ StmtDecl(Decl(p.bareType(side), p.var().name))
                       for p in md.params ]
             + [ Whitespace.NL ]
             + reads + [ self.checkedRead(p.ipdltype, ExprAddrOf(p.var()),
                                          msgexpr, ExprAddrOf(itervar),
                                          errfn, "'%s'" % p.bareType(side).name,
-                                         sentinelKey=p.name)
+                                         sentinelKey=p.name, errfnSentinel=errfnSent)
                         for p in md.params[start:] ]
             + [ self.endRead(msgvar, itervar) ]))
 
         return stmts
 
 
-    def deserializeReply(self, md, replyexpr, side, errfn, actor=None):
+    def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None):
         stmts = [ Whitespace.NL,
                    self.logMessage(md, replyexpr,
                                    'Received reply ', actor, receiving=True) ]
         if 0 == len(md.returns):
             return stmts
 
         itervar = self.itervar
         stmts.extend(
             [ Whitespace.NL,
               StmtDecl(Decl(_iterType(ptr=0), itervar.name),
                        init=ExprCall(ExprVar('PickleIterator'),
                                      args=[ self.replyvar ])) ]
             + [ self.checkedRead(r.ipdltype, r.var(),
                                  ExprAddrOf(self.replyvar),
                                  ExprAddrOf(self.itervar),
                                  errfn, "'%s'" % r.bareType(side).name,
-                                 sentinelKey=r.name)
+                                 sentinelKey=r.name, errfnSentinel=errfnSentinel)
                 for r in md.returns ]
             + [ self.endRead(self.replyvar, itervar) ])
 
         return stmts
 
     def sendAsync(self, md, msgexpr, actor=None):
         sendok = ExprVar('sendok__')
         return (
@@ -4361,35 +4380,36 @@ class _GenerateProtocolActorCode(ipdl.as
         if actor is not None:  stateexpr = _actorState(actor)
         else:                  stateexpr = self.protocol.stateVar()
 
         msgid = md.pqMsgId() if not reply else md.pqReplyId()
         return [ StmtExpr(ExprCall(ExprVar(self.protocol.name +'::Transition'),
                                    args=[ ExprVar(msgid),
                                    ExprAddrOf(stateexpr) ])) ]
 
-    def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype, sentinelKey, sentinel=True):
+    def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype, sentinelKey, errfnSentinel, sentinel=True):
         ifbad = StmtIf(ExprNot(self.read(ipdltype, expr, msgexpr, iterexpr)))
         if isinstance(paramtype, list):
             errorcall = errfn(*paramtype)
+            senterrorcall = errfnSentinel(*paramtype)
         else:
             errorcall = errfn('Error deserializing ' + paramtype)
+            senterrorcall = errfnSentinel('Error deserializing ' + paramtype)
         ifbad.addifstmts(errorcall)
 
         block = Block()
         block.addstmt(ifbad)
 
         if sentinel:
             assert sentinelKey
-
             block.addstmt(Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1))
             read = ExprCall(ExprSelect(msgexpr, '->', 'ReadSentinel'),
                                   args=[ iterexpr, ExprLiteral.Int(hashfunc(sentinelKey)) ])
             ifsentinel = StmtIf(ExprNot(read))
-            ifsentinel.addifstmts(errorcall)
+            ifsentinel.addifstmts(senterrorcall)
             block.addstmt(ifsentinel)
 
         return block
 
     def endRead(self, msgexpr, iterexpr):
         return StmtExpr(ExprCall(ExprSelect(msgexpr, '.', 'EndRead'),
                                  args=[ iterexpr ]))