#!/usr/bin/python
from PIPRes.util.server_import import *
from PIPRes.splits import *
from PIPRes.tree import iterChildren
from PIPRes.cipres_types import CipresTree, toIDLTree, numberedNewickListToCipresTrees
import copy

_LOG = cipresGetLogger('pipres.service_impl.server.tree_merge')

# Debugging switches.
verbose = False     # emit detailed _LOG.debug traces of the merge
showTrees = False   # write .dot renderings of intermediate trees via show()
# Merge policy: when True, a conflict calls collapseClade()/collapseDescendants()
# on the affected trunk node; when False only the intervening edges are
# collapsed (collapseEdge()).
conflictDestroysAllSubTreeStructure = False


class VerticalPath:
    '''A "vertical" path through a tree containing an arbitrary number of
    intervening nodes (in order of root end to tip).

    Attributes:
      lowerNd    -- node the path hangs from (root end).
      trunkNodes -- intervening nodes, root end first; each of them passes
                    all of its common leaves to a single child.
      upperNd    -- first node at/above the starting child whose common
                    leaves are divided among more than one child (or a leaf).
      split      -- mask & upperNd.split; only set when both are truthy.
    '''

    def _findUpperNode(nd, commonLeafSet):
        '''Walks tipward from `nd` while exactly one child carries all of the
        common leaves below `nd` (the others carrying none).

        Returns (upperNode, interveningNodes): upperNode is the first node
        that is a leaf or has a child holding a different, non-empty subset
        of those common leaves; interveningNodes are the nodes passed on the
        way, root end first.'''
        lowerPattern = nd.split & commonLeafSet
        intervening = []
        while True:
            if nd.isLeaf():
                return nd, intervening
            nextNd = None
            for c in iterChildren(nd):
                inters = c.split & commonLeafSet
                if inters:
                    if inters == lowerPattern:
                        nextNd = c
                    else:
                        # common leaves are divided among children: significant
                        return nd, intervening
            intervening.append(nd)
            assert(nextNd)
            nd = nextNd
    _findUpperNode = staticmethod(_findUpperNode)

    def __init__(self, mask, lowerNd, lowerNdChild):
        '''`mask` is the common leaf set; the path starts at `lowerNdChild`,
        a child of `lowerNd`.'''
        self.lowerNd = lowerNd
        upperNd, trunkNdList = VerticalPath._findUpperNode(lowerNdChild, mask)
        self.upperNd = upperNd
        self.trunkNodes = trunkNdList
        if mask and upperNd:
            self.split = mask & (upperNd.split)

    def __str__(self):
        nd = [self.lowerNd] + self.trunkNodes + [self.upperNd]
        return ' -> '.join([str(splitToList(i.split, -1, True)) for i in nd])

    def collapse(self, leafIntersection):
        '''Removes the structure of this path (intervening nodes and upperNd)
        so that the paths above upperNd hang directly from lowerNd.

        Returns (higherPaths, collapseIrr): the VerticalPaths that were above
        upperNd (now with lowerNd as their lowerNd) and the subtrees holding
        no common leaves, which are re-attached to lowerNd.'''
        upperNd = self.upperNd
        if verbose:
            _LOG.debug('in collapse upperNd =')
            #upperNd.writeNewick(sys.stdout)
            _LOG.debug('\n')
        higherPaths, collapseIrr = partitionChildrenWithLeafSet(upperNd, leafIntersection)
        for h in higherPaths:
            h.lowerNd = self.lowerNd
        for c in collapseIrr:
            c.pruneSelf()
        if len(self.trunkNodes) > 0:
            upperNd.pruneSelf(False)
            lowestTrunk = self.trunkNodes[0]
            lowestTrunk.pruneSelf(False)
            if conflictDestroysAllSubTreeStructure:
                lowestTrunk.collapseClade()
            else:
                for n in self.trunkNodes[1:]:
                    n.collapseEdge()
            # side subtrees hanging off the trunk are also "irrelevant"
            collapseIrr.extend(lowestTrunk.children)
            for h in higherPaths:
                # attach the lowest node of each higher path to lowerNd
                # (explicit test instead of `len(x) and a or b`, which picks
                # the wrong node if a node object is ever falsy)
                if h.trunkNodes:
                    self.lowerNd._addChild(h.trunkNodes[0])
                else:
                    self.lowerNd._addChild(h.upperNd)
        else:
            if verbose:
                _LOG.debug('upperNd.par.children = ')
                unp = upperNd.par
                for c in unp.children:
                    _LOG.debug(' %s' % ' '.join(map(str,splitToList(c.split, -1, True))))
            if upperNd.isInternal():
                upperNd.collapseEdge()
            if verbose:
                _LOG.debug( 'after collapse')
                for c in unp.children:
                    _LOG.debug(' %s' % ' '.join(map(str,splitToList(c.split, -1, True))))
        self.lowerNd._addChildren(collapseIrr)
        return higherPaths, collapseIrr

    def scmMergePath(self, toModifyNode, toDesPaths, toRecurse, leafIntersection):
        '''Tries to match this path (from the tree being modified) against
        `toDesPaths` (paths from the tree being destroyed).

        Each candidate is classified with classifyPaths():
          identical       -- upper nodes are matched (trunk collisions are
                             resolved); if internal, the pair is queued on
                             `toRecurse`; the candidate is removed.
          incompat        -- the candidate and this path are both collapsed.
          _firstIsSubset  -- the candidate is collapsed: its irrelevant
                             subtrees move to `toModifyNode` and its higher
                             paths replace it in `toDesPaths`.
          _secondIsSubset -- this path is collapsed.
        `toDesPaths` and `toRecurse` are modified in place.

        Returns [] when an identical match was found; otherwise this path is
        collapsed and the VerticalPaths above it are returned.'''
        found = False
        collapseTM = False
        while not (found or collapseTM) and len(toDesPaths):
            collapseTD = False
            # iterate over a snapshot because toDesPaths is modified below
            toIterate = copy.copy(toDesPaths)
            for desPathNumber, toDesPath in enumerate(toIterate):
                if verbose:
                    _LOG.debug('toDesPath#%d = %s' % (desPathNumber, str(toDesPath)))
                res = classifyPaths(self, toDesPath, leafIntersection)
                if res == SplitCompatEnum.identical:
                    _LOG.debug('identical')
                    if len(toDesPath.trunkNodes) > 0:
                        if len(self.trunkNodes) > 0:
                            # collision: both paths have intervening nodes;
                            # merge them into self.trunkNodes[0]
                            _LOG.debug('collision')
                            self.upperNd.pruneSelf(False)
                            if conflictDestroysAllSubTreeStructure:
                                self.trunkNodes[0].collapseDescendants()
                            else:
                                for n in self.trunkNodes[1:]:
                                    n.collapseEdge()
                            self.trunkNodes[0]._addChild(self.upperNd)
                            toDesPath.upperNd.pruneSelf(collapseDegTwo = False)
                            if conflictDestroysAllSubTreeStructure:
                                toDesPath.trunkNodes[0].collapseDescendants()
                            else:
                                for n in toDesPath.trunkNodes[1:]:
                                    n.collapseEdge()
                            self.trunkNodes[0]._addChildren(toDesPath.trunkNodes[0].children)
                            # The children now belong to self.trunkNodes[0];
                            # clear the stale child link on the emptied trunk
                            # node before pruning it.  (Previously this was
                            # assigned on the VerticalPath object: a no-op.)
                            toDesPath.trunkNodes[0].lChild = None
                            toDesPath.trunkNodes[0].pruneSelf()
                        else:
                            # only the destroyed tree has intervening nodes:
                            # swap self.upperNd with its lowest trunk node,
                            # then swap the two upper nodes (same
                            # swap-and-swap-back idiom as rootPortionSCM)
                            self.upperNd._swapPlaces(toDesPath.trunkNodes[0])
                            toDesPath.upperNd._swapPlaces(self.upperNd)
                    if self.upperNd.isInternal():
                        toRecurse.append((self.upperNd, toDesPath.upperNd))
                    toDesPaths.remove(toDesPath)
                    found = True
                    break
                elif res == SplitCompatEnum.incompat:
                    _LOG.debug('incompat')
                    collapseTM = True
                    collapseTD = True
                elif res == SplitCompatEnum._firstIsSubset:
                    _LOG.debug('_firstIsSubset')
                    collapseTD = True
                elif res == SplitCompatEnum._secondIsSubset:
                    _LOG.debug('_secondIsSubset')
                    collapseTM = True
                if collapseTD:
                    higherPaths, collapsed = toDesPath.collapse(leafIntersection)
                    toModifyNode._stealChildren(collapsed)
                    toDesPaths.remove(toDesPath)
                    toDesPaths.extend(higherPaths)
                    break
                if collapseTM:
                    break
            else:
                # No candidate triggered a break, so nothing changed; another
                # pass over the same list would repeat forever.
                break
        if not found:
            _LOG.debug('not found')
            collapseTM = True
        if collapseTM:
            higherPaths, collapsed = self.collapse(leafIntersection)
            return higherPaths
        return []


def partitionChildrenWithLeafSet(nd, leafIntersection):
    '''returns VerticalPath list for children that have leaves in the
    leafIntersection, and list of other children as a tuple.
    Assumes 'nd' is a significant node in terms of the commonLeafSet'''
    lowerPattern = nd.split & leafIntersection
    assert(lowerPattern)
    paths = []
    irrelevantChildren = []
    for c in iterChildren(nd):
        if c.split & leafIntersection:
            # a significant node never passes all its common leaves to one child
            assert(c.split & leafIntersection != lowerPattern)
            paths.append(VerticalPath(leafIntersection, nd, c))
        else:
            irrelevantChildren.append(c)
    return paths, irrelevantChildren


def moveInterveningNodesToRefEdge(tree, leafIntersection):
    '''Moves "irrelevant" nodes to the left side of the root, maintaining the
    unrooted topology and the root node object.

    Used to guarantee that the root is "significant" with respect to
    leafIntersection and that the path to the reference node (root.lChild on
    entry) contains all intervening nodes.'''
    root = tree.root
    # if a node has this split, then there aren't any other children that
    # have any of the intersection taxa
    keepMovingSplitRep = leafIntersection ^ root.lChild.split
    while True:
        for c in iterChildren(root):
            if c.split & leafIntersection == keepMovingSplitRep:
                # slide root so that root.children are root.lChild.children
                c.pruneSelf(collapseDegTwo = False)
                nextRootChildren = c.children
                c.lChild = None
                c._addChildren(root.children)
                root.lChild = None
                c.split = root.split ^ c.split
                root._addChild(c)
                root._addChildren(nextRootChildren)
                break
            if c.split & keepMovingSplitRep:
                # common taxa are divided among root's children: root is significant
                return


def rootPortionSCM(toModify, toDestroy, leafIntersection):
    '''Roots both trees at a common leaf (the one returned by
    getFirstBitAsIndex(leafIntersection)) and deals with collisions along
    that terminal path.  Afterwards the path to that leaf is the root's
    lChild in each tree, and intervening nodes of both trees on that path
    are gathered in toModify.'''
    lowestCommonLeaf = getFirstBitAsIndex(leafIntersection)
    toModRefNode = toModify.attachAtLeafByIndex(lowestCommonLeaf)
    toDestroyRefNode = toDestroy.attachAtLeafByIndex(lowestCommonLeaf)
    moveInterveningNodesToRefEdge(toModify, leafIntersection)
    moveInterveningNodesToRefEdge(toDestroy, leafIntersection)
    if toModRefNode.par.par:
        if toDestroyRefNode.par.par:
            #collision
            if conflictDestroysAllSubTreeStructure:
                toDestroy.root.lChild.collapseDescendants()
                toModify.root.lChild.collapseDescendants()
            else:
                #reduce to two edges on the refNode to root path
                while toDestroyRefNode.par.par.par:
                    toDestroyRefNode.par.collapseEdge()
                while toModRefNode.par.par.par:
                    toModRefNode.par.collapseEdge()
            # move the destroyed tree's side subtrees onto toModify's
            # intervening node, leaving only the ref leaf under toDestroy's
            p = toModRefNode.par
            toSteal = [c for c in iterChildren(toDestroyRefNode.par) if c != toDestroyRefNode]
            toDestroyRefNode.par.lChild = toDestroyRefNode
            toDestroyRefNode.rSib = None
            p._addChildren(toSteal)
    else:
        if toDestroyRefNode.par.par:
            toModRefNode._swapPlaces(toDestroy.root.lChild)
            # for completeness, we'll move the "correct" reference node back
            # onto its original tree
            toDestroyRefNode._swapPlaces(toModRefNode)


def classifyPaths(firPath, secPath, leafIntersection):
    '''Compares the splits of two VerticalPaths restricted to leafIntersection.

    Returns the compareSplits() verdict when it is not `compat`; otherwise
    SplitCompatEnum._firstIsSubset / _secondIsSubset when one split contains
    the other, else SplitCompatEnum.compat.'''
    #print 'classifyPaths(' + str(splitToList(firPath.split, -1, True)) + ', ' + str(splitToList(secPath.split, -1, True)) + ')'
    comp = compareSplits(firPath.split, secPath.split, leafIntersection, False)[0]
    if comp != SplitCompatEnum.compat:
        return comp
    intersec = firPath.split & secPath.split
    if intersec == firPath.split:
        return SplitCompatEnum._firstIsSubset
    # explicit test: `cond and A or B` would return B whenever A is falsy
    if intersec == secPath.split:
        return SplitCompatEnum._secondIsSubset
    return SplitCompatEnum.compat


def scmMergeSubTree(toModifyNode, toDestroyNode, leafIntersection):
    '''Merges the subtree under toDestroyNode into the one under toModifyNode.

    Children of toDestroyNode holding no common leaves are moved over
    directly; the paths through common leaves are matched pairwise with
    VerticalPath.scmMergePath until no new paths appear, then every matched
    pair of internal upper nodes is merged recursively.'''
    #print 'leafIntersection =', splitToList(leafIntersection, -1, True)
    toModPaths, toModOtherNds = partitionChildrenWithLeafSet(toModifyNode, leafIntersection)
    toDesPaths, toDestroyOtherNds = partitionChildrenWithLeafSet(toDestroyNode, leafIntersection)
    if len(toDestroyOtherNds) > 0:
        toModifyNode._stealChildren(toDestroyOtherNds)
    toRecurse = []
    #print 'toModPaths =\n\t', '\n\t'.join([str(p) for p in toModPaths])
    while True:
        newPaths = []
        for toModPath in toModPaths:
            if verbose:
                _LOG.debug('toModPath = %s' % toModPath)
                _LOG.debug('toDesPaths = %s' % '\n '.join(map(str, toDesPaths)))
            newPaths.extend(toModPath.scmMergePath(toModifyNode, toDesPaths, toRecurse, leafIntersection))
            if verbose:
                _LOG.debug('\n')
        if not newPaths:
            break
        toModPaths = newPaths
    for p in toRecurse:
        scmMergeSubTree(p[0], p[1], leafIntersection)


def scm(treeToModify, treeToDestroy):
    '''Merges treeToDestroy into treeToModify in place and returns
    treeToModify.  (SCM: presumably the strict consensus merger -- confirm.)

    Nodes of treeToDestroy are moved into treeToModify, so treeToDestroy
    must not be reused afterwards.  With exactly two common leaves both
    trees are collapsed and the result is a polytomy.

    Raises ValueError if the trees share fewer than 2 leaves.'''
    leafIntersection = treeToModify.root.split & treeToDestroy.root.split
    _LOG.debug('leaf intersection = %s' % ' '.join(map(str,splitToList(leafIntersection, -1, True))))
    nCommonLeaves = countBits(leafIntersection)
    if nCommonLeaves < 2:
        _LOG.error('trees must have at least 2 common leaves')
        raise ValueError('trees must have at least 2 common leaves')
    if nCommonLeaves == 2:
        # return a polytomy
        treeToModify.root.collapseClade()
        treeToDestroy.root.collapseClade()
        treeToModify.root._stealChildren([c for c in iterChildren(treeToDestroy.root) if not (leafIntersection & c.split)])
    else:
        rootPortionSCM(treeToModify, treeToDestroy, leafIntersection)
        if verbose and showTrees:
            treeToModify.show('toModPostRoot.dot')
            treeToDestroy.show('toDestroyPostRoot.dot')
        if verbose:
            _LOG.debug('treeToDestroy.root.children = ')
            for c in treeToDestroy.root.children:
                _LOG.debug(' %s' % ' '.join(map(str,splitToList(c.split, -1, True))))
        # prune the reference edge that we have already dealt with
        toReattach = treeToModify.root.lChild
        toReattach.pruneSelf(collapseDegTwo = False)
        treeToDestroy.root.lChild.pruneSelf(collapseDegTwo = False)
        scmMergeSubTree(treeToModify.root, treeToDestroy.root, leafIntersection)
        # reattach the root's left child
        toReattach.rSib = treeToModify.root.lChild
        toReattach.par = treeToModify.root
        treeToModify.root.lChild = toReattach
        del toReattach
    treeToModify.root.refreshSplits()
    treeToModify.taxaManager |= treeToDestroy.taxaManager
    if hasattr(treeToModify, 'leafSet'):
        # drop the cached leaf set; the merged tree has new leaves
        del treeToModify.leafSet
    return treeToModify


class SCMTreeMerge(CipresIDL_api1__POA.TreeMerge, SimpleServer):
    '''TreeMerge service that folds a list of trees together with scm().'''

    def __init__(self, registry = None):
        SimpleServer.__init__(self, registry)
        # without an ORB, return bridge trees instead of IDL structs
        self.convertTreeFunc = self.orb is None and toTreeBridge or toIDLTree

    def mergeTrees(self, treeList):
        '''Merges every tree in `treeList` into the first one (left to right)
        and returns the result converted with self.convertTreeFunc.

        Raises CORBA.BAD_PARAM for an empty list.'''
        nTrees = len(treeList)
        _LOG.debug('%d Trees to merge:\n%s\n' % (nTrees, '\n'.join([i.m_newick for i in treeList])))
        if nTrees == 0:
            _LOG.error('Raising CORBA.BAD_PARAM(omniORB.BAD_PARAM_WrongPythonType, COMPLETED_NO) because the sequence length is 0, we need a user exception for this !!')
            raise CORBA.BAD_PARAM(omniORB.BAD_PARAM_WrongPythonType, COMPLETED_NO)
        if nTrees == 1:
            _LOG.debug('Returning %s\n' % treeList[0].m_newick)
            # convert like the multi-tree path so callers get a consistent type
            return self.convertTreeFunc(CipresTree(treeList[0]))
        toMerge = iter(numberedNewickListToCipresTrees(treeList))
        targetTree = toMerge.next()
        targetTree.ignoreEdgeLengths = True
        targetTree.hasEdgeLengths = False
        for n, i in enumerate(toMerge):
            if showTrees:
                CipresTree(targetTree).show('firArg%d.dot' % n)
                CipresTree(i).show('secArg%d.dot' % n)
            scm(targetTree, i)
        if showTrees:
            CipresTree(targetTree).show('final.dot')
        _LOG.debug('Returning %s\n' % targetTree.m_newick)
        return self.convertTreeFunc(targetTree)


if __name__=='__main__':
    if 0:
        # disabled debugging harness
        print('DEBUGGING Run')
        cipresTrees = [CipresTree('(14:8.0,82:9.0,(((85:31.0,((((134:28.0,135:24.0):5.0,(136:12.0,137:14.0):7.0):13.0,(244:28.0,(322:18.0,347:19.0):14.0):9.0):3.0,(((160:11.0,174:8.0):3.0,162:21.0):3.0,(189:12.0,203:12.0):2.0):10.0):9.0):6.0,158:7.0):4.0,92:9.0):4.0);'),
                       CipresTree('((((397:39.0,436:37.0):23.0,471:34.0):38.0,((330:30.0,366:35.0):25.0,392:44.0):15.0):0,((((481:47.0,491:36.0):27.0,363:35.0):0,353:25.0):0,((((300:38.0,163:16.0):0,82:57.0):0,(((22:35.0,152:35.0):0,(48:29.0,244:21.0):0):0,((153:45.0,189:35.0):0,151:19.0):0):0):5.0,((170:56.0,312:47.0):32.0,((((18:25.0,(41:29.0,301:11.0):3.0):11.0,(107:24.0,160:18.0):9.0):0,418:47.0):0,((441:36.0,(349:31.0,((203:9.0,208:9.0):5.0,222:10.0):14.0):0):0,(92:43.0,168:21.0):0):0):24.0):0):0):0,373:27.0);'),
                       CipresTree('(((173:16.0,108:18.0):0,((154:18.0,182:11.0):3.0,(((160:4.0,164:4.0):13.0,(183:20.0,186:34.0):11.0):0,(123:34.0,(126:26.0,150:47.0):6.0):10.0):0):0):3.0,((((((189:7.0,(203:9.0,244:15.0):2.0):13.0,252:30.0):16.0,104:18.0):0,(95:5.0,82:19.0):0):0,((34:15.0,52:23.0):0,92:17.0):0):2.0,(24:20.0,(((292:19.0,35:30.0):0,209:23.0):0,7:33.0):0):3.0):0,(((398:18.0,((118:25.0,369:13.0):4.0,((233:21.0,340:11.0):15.0,480:23.0):5.0):0):0,(444:27.0,(384:24.0,411:18.0):0):0):0,(446:26.0,393:19.0):0):8.0);'),
                       CipresTree('(82:8.0,(92:10.0,160:14.0):2.0,(189:26.0,(203:18.0,244:19.0):16.0):17.0);')
                       ]
        tmInstance = SCMTreeMerge(None)
        from PIPRes.wrap.debug_stub import DebugTreeMergeWrapper
        checkingTM = DebugTreeMergeWrapper(tmInstance)
        t = checkingTM.mergeTrees(cipresTrees)
        print('returned %s' % t.m_newick)
        sys.exit(0)
    try:
        cipresServe(sys.argv, {'TreeMerge': SCMTreeMerge})
    except:
        # top-level service boundary: log whatever escaped
        logException(_LOG)