#!/usr/bin/python
from parser import NexusBlock, NexusBlockStream, NexusParsing
from primitives import *
from PIPRes.util.io import *
from command_reader import NexusCommandReader
from sets import Set
from PIPRes.tree import Tree
import os


def getTreesFromNexus(inF, getTaxaFromAllPublic = True, **kwdsToTreesBlock):
    '''Return a list of all trees found in the NEXUS stream `inF`.

    If `getTaxaFromAllPublic` is true, every handler in
    public_blocks.ALL_PUBLIC_BLOCKS is active (so other public blocks, such
    as a taxa block, are parsed as well) and `kwdsToTreesBlock` is registered
    as the initialization arguments of the TREES handler.  Otherwise only
    TREES blocks are handled, by NexusTreesBlock(**kwdsToTreesBlock).

    Trees are collected, in stream order, from every parsed block that has a
    getTrees() method.'''
    import copy  # needed for copy.copy below; not imported at file level
    if getTaxaFromAllPublic:
        import public_blocks
        handlerDict = copy.copy(public_blocks.ALL_PUBLIC_BLOCKS)
        public_blocks.addBlockInitializationArgs(handlerDict, ['TREES'], kwdsToTreesBlock)
    else:
        handlerDict = {'TREES': (NexusTreesBlock, kwdsToTreesBlock)}
    tList = []
    for b in NexusBlockStream(inF, handlerDict, True, []):
        # Only the attribute lookup is guarded, so an AttributeError raised
        # inside a block's getTrees() is not mistaken for "no trees here".
        try:
            getTrees = b.getTrees
        except AttributeError:
            continue
        tList.extend(getTrees())
    return tList


def getTreesFromNexusString(s, getTaxaFromAllPublic = True, **kwdsToTreesBlock):
    '''Return the trees in the NEXUS text `s` (arguments as for getTreesFromNexus).'''
    import cStringIO
    return getTreesFromNexus(cStringIO.StringIO(s), getTaxaFromAllPublic, **kwdsToTreesBlock)


def getTreesFromNexusFileName(inFilename, getTaxaFromAllPublic = True, **kwdsToTreesBlock):
    '''Return the trees in the NEXUS file named `inFilename`.

    The file is opened via openInPathAndCall; `getTaxaFromAllPublic` and
    `kwdsToTreesBlock` are forwarded to getTreesFromNexus (previously
    `getTaxaFromAllPublic` was silently ignored).'''
    def _readTrees(inF):
        # assumes openInPathAndCall calls its callback with the opened file
        # as the first argument -- TODO confirm against PIPRes.util.io
        return getTreesFromNexus(inF, getTaxaFromAllPublic, **kwdsToTreesBlock)
    return openInPathAndCall(inFilename, _readTrees)


class NexusTreesTranslateReader:
    '''Command reader for the Translate command of a TREES block.

    Parses `key value [, key value]* ;` and passes the resulting table to the
    trees block's addTranslateTable().'''

    def __init__(self, treesBlock):
        self.treesBlock = treesBlock

    def readCommand(self, translateToken, cStream, obj = None, blockObj = None):
        '''Read one Translate command from `cStream` and return True.

        Raises NexusAfterTokenError if the trees block already has a label
        translation table or a key is repeated, and NexusMissingTokenError if
        a key/value pair is followed by something other than ',' or ';'.'''
        if blockObj is not None:
            self.treesBlock = blockObj
        if 'labelTranslation' in self.treesBlock.__dict__ and len(self.treesBlock.labelTranslation) > 0:
            raise NexusAfterTokenError(translateToken, 'Translation table already constructed: Translate command (if used) must occur before any trees (and there can only be one translate command).')
        tokStream = cStream.getTokenStream()
        translateDict = {}
        transKeyOrder = []
        while True:
            k, v = str(tokStream.next()), str(tokStream.next())
            if k in translateDict:
                raise NexusAfterTokenError(k, 'Translate keys cannot occur more than once (%s was duplicated)' % str(k))
            translateDict[k] = v
            transKeyOrder.append(k)
            sep = tokStream.next()
            if str(sep) == ';':
                break
            if str(sep) != ',':
                raise NexusMissingTokenError(', or ;', sep)
        self.treesBlock.addTranslateTable(translateDict, transKeyOrder)
        return True


class NexusTreeReader:
    '''Command reader for the Tree command of a TREES block (`name = tree ;`).'''
    import re
    treeNamePattern = re.compile(r'.*[a-zA-Z0-9]+.*')

    def __init__(self, treesBlock):
        self.treesBlock = treesBlock

    def isValidTreeName(s):
        '''Return a match object (true) iff `s` contains an ASCII letter or digit.'''
        return NexusTreeReader.treeNamePattern.match(s)
    isValidTreeName = staticmethod(isValidTreeName)

    def readCommand(self, treeCommandNameToken, cStream, obj = None, blockObj = None):
        '''Read one tree definition from `cStream`, add it to the trees block
        and return True.

        Raises NexusIllegalName if the tree name has no letter or digit.'''
        if blockObj is not None:
            self.treesBlock = blockObj
        # Test against None rather than truthiness: the trees block is also a
        # taxa manager and may be falsy while it has no taxa (if
        # ContainedTaxaManager defines __len__ -- not visible here), which is
        # exactly when the chooseTaxaManager() hook is needed.
        if self.treesBlock is not None:
            self.treesBlock.chooseTaxaManager()
        tokStream = cStream.getTokenStream()
        n = tokStream.next()
        treeName = str(n)
        if not NexusTreeReader.isValidTreeName(treeName):
            raise NexusIllegalName('Tree', '', n)
        tokStream.demandToken('=')
        t = Tree(name = treeName, taxaManager = self.treesBlock)
        t.readTokenStream(tokStream)
        self.treesBlock.addTree(t)
        return True


class NexusTreesBlock(NexusBlock, ContainedTaxaManager):
    '''Reader and container for a NEXUS TREES block.

    Handles the Translate and Tree commands, stores the trees read (see
    getTrees) and maps taxon labels/numbers used in trees to 0-based taxon
    indices (see translateTaxLabel).'''
    cmdHandlers = [
        NexusCommandReader('Translate', readerToCreate = NexusTreesTranslateReader),
        NexusCommandReader('Tree', readerToCreate = NexusTreeReader),
    ]

    def __init__(self, beginCmd = None, commandStream = None, previousBlocks = None, **kwds):
        '''In kwds, uses taxaManager if present; if not, uses taxaNamingStyle and taxLabels.'''
        self.taxaManager = getTaxaManagerFromDictArgs(kwds)
        taxNS = kwds.get('taxaNamingStyle', TaxaNamingEnum.acceptNumbers)
        # Under these naming styles an unknown all-digit label is read as a
        # 1-based taxon number (see translateTaxLabel).
        self.treatUknownNumbersAsIndices = (taxNS == TaxaNamingEnum.numbersOnly) or (taxNS == TaxaNamingEnum.labelsOrNumbers)
        self.labelTranslation = {}
        if commandStream is None:
            # presumably NexusBlock.__init__ calls prepareToRead itself when
            # it has a stream to read -- TODO confirm
            self.prepareToRead(previousBlocks or [])
        self.allowDifferingLeafSets = False
        NexusBlock.__init__(self, beginCmd, commandStream, previousBlocks or [])

    def chooseTaxaManager(self):
        '''Hook called right before a tree is read.

        If there is no label translation yet and no (true) taxaManager, and
        all candidate taxa blocks from prepareToRead compare equal, that
        block's taxa are adopted as this block's taxa.  If the candidates
        differ the choice is ambiguous and nothing is done.'''
        if len(self.labelTranslation) > 0 or (hasattr(self, 'taxaManager') and self.taxaManager):
            return
        if len(self.potentialTaxaBlocks) > 0:
            firstTaxaBlock = self.potentialTaxaBlocks[0]
            for p in self.potentialTaxaBlocks[1:]:
                if firstTaxaBlock != p:
                    return # ambiguous taxa block don't use any of the previous blocks
            self.addNewLabels(firstTaxaBlock.getTaxLabels())
            self.taxaManager = firstTaxaBlock

    def prepareToRead(self, previousBlocks):
        '''Reset per-block state before reading a block.

        Every previous block whose getTaxLabels() call succeeds is recorded
        in potentialTaxaBlocks as a candidate source of taxa.'''
        self.potentialTaxaBlocks = []
        for p in previousBlocks:
            try:
                p.getTaxLabels()
            except AttributeError:
                continue  # not a block that holds taxa
            self.potentialTaxaBlocks.append(p)
        self.labelTranslation = {}
        self.trees = []

    def getTrees(self):
        '''Return the list of trees read so far (the internal list, not a copy).'''
        return self.trees

    def addNewLabels(self, labels):
        '''Append `labels` as new taxa numbered after the existing ones.

        Each new taxon is registered in labelTranslation under its upper-cased
        label and under its 1-based number.  Raises AssertionError if the
        taxa are final.'''
        if self.taxaAreFinal():
            raise AssertionError('Adding taxa to finalized taxa manager.')
        n = len(self.getTaxLabels())
        for i, l in enumerate(labels):
            ind = n + i
            self.labelTranslation[str(ind + 1)] = ind
            self.labelTranslation[l.upper()] = ind
        self.taxaManager.addTaxa(labels)

    def addTree(self, t):
        '''Store tree `t`, first committing any labels queued in newLabels by
        translateTaxLabel.'''
        if len(self.__dict__.get('newLabels', [])) > 0:
            self.addNewLabels(self.newLabels)
            del self.newLabels
        self.trees.append(t)

    def translateTaxLabel(self, tLabel):
        '''Return the 0-based taxon index for the label or number `tLabel`.

        Lookup order: labelTranslation (upper-cased key), then taxaManager.
        If both fail and the taxaManager's taxa are not final, the label is
        queued in newLabels (committed later by addTree) and receives the
        index it will have once committed.  For an all-digit label with
        treatUknownNumbersAsIndices set, it is instead validated and read as
        a 1-based number.  Otherwise NexusUnknownTaxonError is raised.'''
        capLabel = str(tLabel).upper()
        if hasattr(self, 'labelTranslation'):
            ind = self.labelTranslation.get(capLabel)
            if ind is not None:
                return ind
        try:
            return self.taxaManager.translateTaxLabel(tLabel)
        except NexusUnknownTaxonError:
            pass
        tl = self.taxaManager.getTaxLabels()
        if self.taxaManager.taxaAreFinal():
            raise NexusUnknownTaxonError(tLabel, tl)
        pos = None
        if 'newLabels' in self.__dict__:
            # `index` comes from a star import; presumably like list.index
            # with a comparator (raises ValueError when absent) -- confirm.
            try:
                pos = index(capLabel, self.newLabels, lambda c, u: c == u.upper())
            except ValueError:
                pass
        else:
            self.newLabels = []
        if pos is None:
            self.newLabels.append(str(tLabel))
            pos = len(self.newLabels) - 1
        if capLabel.isdigit() and self.treatUknownNumbersAsIndices:
            self.taxaManager.validateTaxonName(capLabel, [])
            return int(capLabel) - 1
        # addNewLabels commits newLabels[pos] at index len(tl) + pos, so a
        # repeated lookup must return that too (it used to return bare pos).
        return len(tl) + pos

    def addTranslateTable(self, translateDict, translateKeyOrder):
        '''Install a Translate table: keys (in `translateKeyOrder`) map to the
        labels in `translateDict`.

        If this block has no taxLabels of its own, the first candidate taxa
        block that can translate every value becomes the taxaManager; if
        taxLabels are still absent afterwards, the translated values
        (de-duplicated, first occurrence kept) become this block's taxLabels.
        Finally the label translation dict is rebuilt.'''
        if 'taxLabels' in self.__dict__:
            NexusTreesBlock.validateTranslateTable(translateDict, self, translateKeyOrder)
        else:
            for t in self.potentialTaxaBlocks:
                try:
                    if NexusTreesBlock.validateTranslateTable(translateDict, t, translateKeyOrder):
                        self.taxaManager = t
                        break
                except NexusUnknownTaxonError:
                    pass
            if 'taxLabels' not in self.__dict__:
                self.taxLabels = [translateDict[k] for k in translateKeyOrder]
                if len(Set(self.taxLabels)) != len(self.taxLabels):
                    self.taxLabels = stableUnique(self.taxLabels)
        self.createLabelTranslationDict(translateDict, translateKeyOrder)

    def createLabelTranslationDict(self, translateDict, translateKeyOrder):
        '''Rebuild labelTranslation from scratch.

        Each translate key and its value (both upper-cased) map to the index
        translateTaxLabel gives for the value; then every taxon's 1-based
        number and upper-cased label map to its 0-based index.'''
        self.labelTranslation = {}
        for k in translateKeyOrder:
            v = translateDict[k]
            inds = self.translateTaxLabel(v)
            self.labelTranslation[k.upper()] = inds
            self.labelTranslation[v.upper()] = inds
        for i in range(self.getNTax()):
            self.labelTranslation[str(i + 1)] = i
            lab = self.getTaxLabel(i).upper()
            self.labelTranslation[lab] = i

    def validateTranslateTable(translateDict, taxaContext, translateKeyOrder):
        '''Translate every value of the table (in key order) with
        taxaContext.translateTaxLabel and return True; lookup errors propagate.'''
        for k in translateKeyOrder:
            taxaContext.translateTaxLabel(translateDict[k])
        return True
    validateTranslateTable = staticmethod(validateTranslateTable)