from PIPRes.exceptions import CipresException
from PIPRes.wrap.idl import *
from PIPRes.cipres_types import toIDLTree, toIDLMatrix, CipresTree, CipresDiscreteMatrix, reorderMatrix, mapLeaves
from PIPRes.util.cipres import cipresGetLogger
from sets import Set, ImmutableSet
from PIPRes.wrap.stub import PIPResInterfaceWrapper
from PIPRes.tree import BadTreeDefError
import copy
import sys

# Client wrappers that check preconditions on arguments and postconditions on
# values returned by the wrapped object reference, raising CipresException
# when a check fails.
#
# Exceptions are caught with a bare ``except CipresException:`` and fetched
# via sys.exc_info()[1] instead of ``except CipresException, ex:`` so the
# module parses regardless of interpreter version.
#
# NOTE(review): raisePyExcepFromCipresExcep is not defined or explicitly
# imported here; presumably it comes from ``from PIPRes.wrap.idl import *``.

_LOG = cipresGetLogger('pipres.wrap.debug_stub')

# Minimum size of the intersection of all returned leaf sets when
# validateTreeDecomposeSets is called with recIDCMConstraints = True.
_MIN_SHARED_TAXA = 4


def _majorCode(treatAsArg):
    '''Return the CipresException major code to use for a failed check.

    MAJOR_BAD_ARG when the checked object is an argument, MAJOR_BAD_RETURN
    when it is a value returned by the wrapped object.  A plain if/else is
    used because the ``cond and a or b`` idiom yields ``b`` whenever ``a``
    is falsy (e.g. if MAJOR_BAD_ARG were 0).'''
    if treatAsArg:
        return CipresException.MAJOR_BAD_ARG
    return CipresException.MAJOR_BAD_RETURN


def validateTreeDecomposeSets(fullLeafSet, returnedSets, treatAsArg = True, recIDCMConstraints = True):
    '''Check that `returnedSets` exactly cover `fullLeafSet`.

    fullLeafSet        -- iterable of the taxa that must be covered.
    returnedSets       -- sequence of iterables of taxa (one per subset).
    treatAsArg         -- selects MAJOR_BAD_ARG (True) or MAJOR_BAD_RETURN.
    recIDCMConstraints -- if True, the running intersection of all returned
                          sets must contain at least 4 taxa.

    Raises CipresException if no sets are returned, if the intersection
    constraint is violated, or if the union of the sets omits or adds taxa
    relative to fullLeafSet.  Returns True otherwise.'''
    majorCode = _majorCode(treatAsArg)
    inLeafSet = Set(fullLeafSet)
    setOfSets = [Set(i) for i in returnedSets]
    nSubSets = len(setOfSets)
    if nSubSets == 0:
        raise CipresException(majorCode, message = 'No Leaf Sets')
    # Copy so that union_update below never mutates setOfSets[0].
    unionSet = Set(setOfSets[0])
    if nSubSets > 1:
        firstReturned = setOfSets[0]
        secondReturned = setOfSets[1]
        unionSet.union_update(secondReturned)
        if recIDCMConstraints:
            interSet = Set(firstReturned.intersection(secondReturned))
            if len(interSet) < _MIN_SHARED_TAXA:
                msg = 'Sets %d, %s, and %d, %s, have an intersection, %s, of fewer than %d taxa' % (0, str(firstReturned), 1, str(secondReturned), str(interSet), _MIN_SHARED_TAXA)
                raise CipresException(majorCode, message = msg)
        for j in range(2, nSubSets):
            currSet = setOfSets[j]
            unionSet.union_update(currSet)
            if recIDCMConstraints:
                # interSet holds the intersection of sets 0..j.
                interSet.intersection_update(currSet)
                if len(interSet) < _MIN_SHARED_TAXA:
                    msg = 'Set intersection for sets 0 - %d contains only %d taxa (%s)' % (j, len(interSet), str(interSet))
                    raise CipresException(majorCode, message = msg)
    if unionSet != inLeafSet:
        omitted = str(inLeafSet.difference(unionSet))
        if unionSet.issubset(inLeafSet):
            raise CipresException(majorCode, message = 'Taxa (%s) omitted.' % omitted)
        added = str(unionSet.difference(inLeafSet))
        if inLeafSet.issubset(unionSet):
            raise CipresException(majorCode, message = 'Taxa (%s) added.' % added)
        raise CipresException(majorCode, message = 'Taxa (%s) omitted and taxa (%s) added' % (omitted, added))
    return True


def validateTree(tree, **kwargs):
    '''Check that `tree` can be converted to a CipresTree.

    Keyword arguments:
      treatAsArg     (default True)  -- MAJOR_BAD_ARG vs MAJOR_BAD_RETURN.
      hasEdgeLengths (default False) -- if true, also require edge lengths.

    Raises CipresException with minor code MINOR_TREE if CipresTree(tree)
    raises BadTreeDefError, if tree.m_leafSet is non-empty and differs from
    the parsed tree's m_leafSet, or if required edge lengths are missing.
    Returns True otherwise.'''
    majorCode = _majorCode(kwargs.get('treatAsArg', True))
    try:
        t = CipresTree(tree)
    except BadTreeDefError:
        raise CipresException(majorCode, CipresException.MINOR_TREE, str(sys.exc_info()[1]))
    # An empty m_leafSet is accepted; a non-empty one must match the parse.
    origLeafSet = tree.m_leafSet
    if len(origLeafSet) > 0 and origLeafSet != t.m_leafSet:
        raise CipresException(majorCode, CipresException.MINOR_TREE, 'm_leafSet is not empty or valid')
    # NOTE(review): assumes CipresTree.hasEdgeLengths is a data attribute; if
    # it is a method, `not t.hasEdgeLengths` is always False -- confirm.
    if kwargs.get('hasEdgeLengths', False) and not t.hasEdgeLengths:
        raise CipresException(majorCode, CipresException.MINOR_TREE, 'tree %s does not have edge lengths' % str(t))
    return True


def validateMatrix(tree, **kwargs):
    '''Placeholder matrix validator; currently performs no checks.

    The first positional argument is the matrix to check (the parameter name
    `tree` is kept for backward compatibility).  Keyword arguments such as
    treatAsArg are accepted but ignored; only a warning is logged.'''
    _LOG.warn('validateMatrix is not actually checking anything right now.')


class DebugTreeDecomposeWrapper(PIPResInterfaceWrapper):
    '''Debugging wrapper around a TreeDecompose object reference.'''

    def __init__(self, objRef):
        PIPResInterfaceWrapper.__init__(self, objRef)

    def leafSetDecompose(self, inTree):
        '''Validate `inTree`, forward it to the wrapped leafSetDecompose, and
        check that the returned leaf sets cover the tree's leaf set.

        The argument must have edge lengths and at least 4 taxa.'''
        try:
            _LOG.info('Validating argument to leafSetDecompose')
            tree = toIDLTree(inTree)
            validateTree(tree, treatAsArg = True, hasEdgeLengths = True)
            ls = tree.getLeafSet()
            if len(ls) < 4:
                raise CipresException(CipresException.MAJOR_BAD_ARG, CipresException.MINOR_TREE, 'Tree with only %d taxa sent to decomposition' % len(ls))
            returned = self.objRef.leafSetDecompose(tree)
            _LOG.info('Validating return of leafSetDecompose')
            # These sets were produced by the wrapped object, so any failure
            # is reported as a bad return value.
            validateTreeDecomposeSets(ls, returned, treatAsArg = False, recIDCMConstraints = True)
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])
        return returned


class DebugTreeMergeWrapper(PIPResInterfaceWrapper):
    '''Debugging wrapper around a TreeMerge object reference.'''

    def __init__(self, objRef):
        PIPResInterfaceWrapper.__init__(self, objRef)

    def mergeTrees(self, trees):
        '''Validate each tree in `trees`, forward them to the wrapped
        mergeTrees, and check that the returned tree's leaf set equals the
        union of the argument trees' leaf sets.'''
        try:
            _LOG.info('Validating argument to mergeTrees')
            treesInIDL = map(toIDLTree, trees)
            fullLeafSet = Set()
            for t in treesInIDL:
                validateTree(t, treatAsArg = True)
                fullLeafSet.union_update(Set(t.getLeafSet()))
            returned = self.objRef.mergeTrees(treesInIDL)
            _LOG.info('Validating returned tree from mergeTrees')
            validateTree(returned, treatAsArg = False)
            returnedLeafSet = Set(returned.getLeafSet())
            if returnedLeafSet != fullLeafSet:
                retList = list(returnedLeafSet)
                retList.sort()
                fullList = list(fullLeafSet)
                fullList.sort()
                raise CipresException(CipresException.MAJOR_BAD_RETURN, CipresException.MINOR_TREE, 'Returned tree has leaf set of %s, but union of argument leaf sets was %s' % (', '.join(map(str, retList)), ', '.join(map(str, fullList))))
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])
        return returned


class DebugTreeMatrixDependentWrapper(PIPResInterfaceWrapper):
    '''Base for wrappers whose operations require setTree() and setMatrix()
    to have been called first.  Keeps CipresTree / CipresDiscreteMatrix
    copies of what was sent so later results can be checked against them.'''

    def __init__(self, objRef):
        PIPResInterfaceWrapper.__init__(self, objRef)
        self.tree = None
        self.matrix = None

    def setTree(self, inTree):
        '''Validate `inTree`, remember it, and forward it to the wrapped setTree.'''
        try:
            tree = toIDLTree(inTree)
            _LOG.debug('Validating argument to setTree')
            validateTree(tree, treatAsArg = True)
            self.tree = CipresTree(tree)
            _LOG.debug('Tree valid calling object reference setTree')
            retCode = self.objRef.setTree(tree)
            _LOG.debug('back from object reference setTree')
            return retCode
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def setMatrix(self, matrix):
        '''Validate `matrix`, remember it, and forward it to the wrapped setMatrix.'''
        try:
            validateMatrix(matrix)
            self.matrix = CipresDiscreteMatrix(matrix)
            return self.objRef.setMatrix(toIDLMatrix(matrix))
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def _demandTreeAndMatrix(self, methodName):
        '''Raise CipresException (MAJOR_API) if setTree() or setMatrix() has
        not been called before `methodName`.'''
        if self.tree is None:
            raise CipresException(CipresException.MAJOR_API, CipresException.MINOR_TREE, 'SetTree not called before %s' % str(methodName))
        if self.matrix is None:
            raise CipresException(CipresException.MAJOR_API, CipresException.MINOR_DATA_MATRIX, 'SetMatrix not called before %s' % str(methodName))

    def _demandSameLeafSet(self, returned):
        '''Raise CipresException (MAJOR_BAD_RETURN) if `returned` does not
        have the same leaf set as the tree passed to setTree().'''
        origLeafSet = Set(self.tree.getLeafSet())
        returnedLeafSet = Set(returned.getLeafSet())
        if returnedLeafSet != origLeafSet:
            raise CipresException(CipresException.MAJOR_BAD_RETURN, CipresException.MINOR_TREE, 'Returned tree has leaf set of %s, but the tree sent to setTree() had the leaf set of %s' % (str(returnedLeafSet), str(origLeafSet)))


class DebugTreeImproveWrapper(DebugTreeMatrixDependentWrapper):
    '''Debugging wrapper around a TreeImprove object reference.'''

    def __init__(self, objRef):
        DebugTreeMatrixDependentWrapper.__init__(self, objRef)

    def improveTree(self, proxyConsumer):
        '''Require setTree/setMatrix, call the wrapped improveTree, and check
        that the returned tree is valid and keeps the original leaf set.'''
        try:
            self._demandTreeAndMatrix('improveTree')
            _LOG.debug('calling object reference improveTree()')
            returned = self.objRef.improveTree(proxyConsumer)
            _LOG.info('Validating tree returned from improveTree')
            validateTree(returned, treatAsArg = False)
            self._demandSameLeafSet(returned)
            return returned
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])


class DebugTreeRefineWrapper(DebugTreeMatrixDependentWrapper):
    '''Debugging wrapper around a TreeRefine object reference.'''

    def __init__(self, objRef):
        DebugTreeMatrixDependentWrapper.__init__(self, objRef)

    def refineTree(self):
        '''Require setTree/setMatrix, call the wrapped refineTree, and check
        that the returned tree is valid and keeps the original leaf set.'''
        try:
            self._demandTreeAndMatrix('refineTree')
            returned = self.objRef.refineTree()
            _LOG.info('Validating tree returned from refineTree')
            validateTree(returned, treatAsArg = False)
            self._demandSameLeafSet(returned)
            return returned
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])


class DebugMatrixAlignWrapper(PIPResInterfaceWrapper):
    '''Debugging wrapper around a MatrixAlign object reference.

    The taxa given to setTaxa() are copied and later passed, together with
    the taxa reported by the wrapped object's getTaxa(), to reorderMatrix /
    mapLeaves when the aligned matrix or tree is fetched.'''

    # Class-level default so a missing setTaxa() call is detected without
    # relying on instance-attribute lookup falling through.
    _inTaxa = None

    def _demandTaxa(self, methodName):
        '''Raise CipresException (MAJOR_API) if setTaxa() has not been called.'''
        if self._inTaxa is None:
            raise CipresException(CipresException.MAJOR_API, message = 'setTaxa not called before %s' % str(methodName))

    def setTaxa(self, taxa):
        '''Remember a copy of `taxa` and forward them to the wrapped setTaxa.'''
        try:
            _LOG.debug('MatrixAlign.setTaxa')
            self._inTaxa = copy.copy(taxa)
            return self.objRef.setTaxa(taxa)
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def setInputMatrix(self, characters):
        '''Validate `characters`, remember it, and forward it to the wrapped
        setInputMatrix.'''
        try:
            _LOG.debug('MatrixAlign.setInputMatrix')
            validateMatrix(characters)
            self.matrix = CipresDiscreteMatrix(characters)
            return self.objRef.setInputMatrix(toIDLMatrix(characters))
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def setGuideTree(self, tree):
        '''Validate `tree`, remember it, and forward it to the wrapped
        setGuideTree.'''
        try:
            _LOG.debug('MatrixAlign.setGuideTree')
            idlTree = toIDLTree(tree)
            _LOG.debug('Validating argument to setGuideTree')
            validateTree(idlTree, treatAsArg = True)
            self.tree = CipresTree(idlTree)
            _LOG.debug('Tree valid calling object reference setGuideTree')
            retCode = self.objRef.setGuideTree(idlTree)
            _LOG.debug('back from object reference setGuideTree')
            return retCode
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def getTaxa(self):
        '''Forward to the wrapped getTaxa.'''
        try:
            _LOG.debug('MatrixAlign.getTaxa')
            return self.objRef.getTaxa()
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def getAlignedMatrix(self):
        '''Fetch the aligned matrix from the wrapped object and reorder it
        using the wrapped object's taxa and the taxa given to setTaxa().'''
        try:
            self._demandTaxa('getAlignedMatrix')
            _LOG.debug('MatrixAlign.getAlignedMatrix getting aligned matrix')
            matrix = self.objRef.getAlignedMatrix()
            validateMatrix(matrix, treatAsArg = False)
            _LOG.debug('MatrixAlign.getAlignedMatrix calling getTaxa')
            self._outTaxa = copy.copy(self.objRef.getTaxa())
            _LOG.debug('MatrixAlign.getAlignedMatrix reordering matrix')
            # NOTE(review): reorderMatrix's return value is discarded, so this
            # assumes it reorders `matrix` in place -- confirm.
            reorderMatrix(matrix, self._outTaxa, self._inTaxa)
            return matrix
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def getTree(self, proxyConsumer):
        '''Fetch a tree from the wrapped object, validate it, and map its
        leaves using the wrapped object's taxa and the taxa given to setTaxa().'''
        try:
            self._demandTaxa('getTree')
            _LOG.debug('MatrixAlign.getTree')
            tree = CipresTree(self.objRef.getTree(proxyConsumer))
            self._outTaxa = copy.copy(self.objRef.getTaxa())
            validateTree(tree, treatAsArg = False)
            # NOTE(review): mapLeaves's return value is discarded, so this
            # assumes it relabels `tree` in place -- confirm.
            mapLeaves(tree, self._outTaxa, self._inTaxa)
            return tree
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])

    def getScore(self):
        '''Forward to the wrapped getScore.'''
        try:
            _LOG.debug('MatrixAlign.getScore')
            return self.objRef.getScore()
        except CipresException:
            raisePyExcepFromCipresExcep(sys.exc_info()[1])


# Maps an interface name to the debug wrapper class used for that interface.
objRefTranslator = {
    'MatrixAlign': DebugMatrixAlignWrapper,
    'TreeDecompose': DebugTreeDecomposeWrapper,
    'TreeMerge': DebugTreeMergeWrapper,
    'TreeImprove': DebugTreeImproveWrapper,
    'TreeRefine': DebugTreeRefineWrapper,
}