#!/usr/bin/python
"""Reading of NEXUS TREES blocks.

Provides helpers that collect every tree from a NEXUS stream, string or file
name, the command readers for the TREES block's ``Translate`` and ``Tree``
commands, and ``NexusTreesBlock``, which stores the trees read and the
label -> taxon-index translation used while reading them.
"""
from parser import NexusBlock, NexusBlockStream, NexusParsing
from primitives import *
from cipres.util.io import *
from command_reader import NexusCommandReader
from sets import Set
from cipres.tree import new_tree
import os


def getTreesFromNexus(inF, getTaxaFromAllPublic=True, **kwdsToTreesBlock):
    """Return a list of every tree found in the NEXUS stream `inF`.

    If `getTaxaFromAllPublic` is true, all handlers from
    ``public_blocks.ALL_PUBLIC_BLOCKS`` are registered (with
    `kwdsToTreesBlock` attached to the TREES handler); otherwise only TREES
    blocks are handled, constructed as ``NexusTreesBlock(**kwdsToTreesBlock)``.
    Blocks that have no ``getTrees`` method contribute nothing.
    """
    import copy  # explicit: do not rely on a star import providing `copy`
    tList = []
    if getTaxaFromAllPublic:
        import public_blocks
        # Shallow copy so the module-level handler table is not mutated.
        handlerDict = copy.copy(public_blocks.ALL_PUBLIC_BLOCKS)
        public_blocks.addBlockInitializationArgs(handlerDict, ['TREES'], kwdsToTreesBlock)
    else:
        handlerDict = {'TREES': (NexusTreesBlock, kwdsToTreesBlock)}
    for b in NexusBlockStream(inF, handlerDict, True, []):
        try:
            tList.extend(b.getTrees())
        except AttributeError:
            pass  # not a tree-bearing block
    return tList


def getTreesFromNexusString(s, getTaxaFromAllPublic=True, **kwdsToTreesBlock):
    """Like getTreesFromNexus, but reads the NEXUS content from the string `s`."""
    import cStringIO
    return getTreesFromNexus(cStringIO.StringIO(s), getTaxaFromAllPublic, **kwdsToTreesBlock)


def getTreesFromNexusFileName(inFilename, getTaxaFromAllPublic=True, **kwdsToTreesBlock):
    """Like getTreesFromNexus, but opens `inFilename` via openInPathAndCall.

    `getTaxaFromAllPublic` is now forwarded; it used to be silently dropped.
    """
    # NOTE(review): assumes openInPathAndCall forwards keyword arguments to the
    # callable, as the original already relied on for **kwdsToTreesBlock.
    return openInPathAndCall(inFilename, getTreesFromNexus,
                             getTaxaFromAllPublic=getTaxaFromAllPublic,
                             **kwdsToTreesBlock)


class NexusTreesTranslateReader:
    """Command reader for the TREES block's Translate command."""

    def __init__(self, treesBlock):
        self.treesBlock = treesBlock

    def readCommand(self, translateToken, cStream, obj=None, blockObj=None):
        """Parse ``key value [, key value]* ;`` and register the table.

        Raises NexusError if a translation table already exists or a key is
        repeated, and NexusMissingTokenError if a pair is not followed by
        ``,`` or ``;``. Returns True.
        """
        if blockObj is not None:
            self.treesBlock = blockObj
        if 'labelTranslation' in self.treesBlock.__dict__ and len(self.treesBlock.labelTranslation) > 0:
            msg = "Translation table already constructed: Translate command"\
                  " (if used) must occur before any trees (and there can only"\
                  " be one translate command)."
            raise NexusError(msg, token=translateToken)
        tok_stream = cStream.getTokenStream()
        translateDict = {}
        transKeyOrder = []
        while True:
            keyTok = tok_stream.next()
            k, v = str(keyTok), str(tok_stream.next())
            if k in translateDict:
                msg = "Translate keys cannot occur more than once (%s was"\
                      " duplicated)" % str(k)
                # Pass the token object (not its string) like the other raises.
                raise NexusError(msg, token=keyTok)
            translateDict[k] = v
            transKeyOrder.append(k)
            sepTok = tok_stream.next()
            sep = str(sepTok)
            if sep == ';':
                break
            if sep != ',':
                raise NexusMissingTokenError(', or ;', sep, token=sepTok)
        self.treesBlock.addTranslateTable(translateDict, transKeyOrder)
        return True


class NexusTreeReader:
    """Command reader for ``TREE name = <newick>``."""
    import re
    treeNamePattern = re.compile(r'.*[a-zA-Z0-9]+.*')

    def __init__(self, treesBlock):
        self.treesBlock = treesBlock

    def isValidTreeName(s):
        """Return a match (truthy) if `s` contains at least one ASCII letter or digit."""
        return NexusTreeReader.treeNamePattern.match(s)
    isValidTreeName = staticmethod(isValidTreeName)

    def readCommand(self, treeCommandNameToken, cStream, obj=None, blockObj=None):
        """Read one tree definition and add it to the trees block. Returns True."""
        if blockObj is not None:
            self.treesBlock = blockObj
        # NOTE(review): truthiness test; if the block type defines __len__, an
        # empty block is falsy and chooseTaxaManager is skipped -- confirm
        # `is not None` was not intended.
        if self.treesBlock:
            self.treesBlock.chooseTaxaManager()
        tok_stream = cStream.getTokenStream()
        n = tok_stream.next()
        treeName = str(n)
        if not NexusTreeReader.isValidTreeName(treeName):
            raise NexusIllegalLabelError('Tree', '', n)
        tok_stream.require_next_token('=')
        t = new_tree(newick=tok_stream, name=treeName, taxa_mgr=self.treesBlock)
        self.treesBlock.addTree(t)
        return True


class NexusTreesBlock(NexusBlock, ContainedTaxaManager):
    """A NEXUS TREES block.

    Holds the trees read (`trees`) and `labelTranslation`, a dict from
    upper-cased labels (translate keys, taxon labels and 1-based index
    strings) to 0-based taxon indices.
    """
    cmdHandlers = [
        NexusCommandReader('Translate', readerToCreate=NexusTreesTranslateReader),
        NexusCommandReader('Tree', readerToCreate=NexusTreeReader),
    ]

    def __init__(self, beginCmd=None, commandStream=None, previousBlocks=None, **kwds):
        """Create the block.

        The taxa manager is obtained from `kwds` via
        get_taxa_manager_from_args (uses taxa_mgr if present, otherwise
        tax_naming_style and tax_labels). Unknown numeric labels are treated
        as indices when tax_naming_style is NUMBERS_ONLY or LABELS_OR_NUMBERS.
        """
        self.taxa_mgr = get_taxa_manager_from_args(kwds)
        taxNS = kwds.get('tax_naming_style', TaxonNaming.ACCEPT_NUMBERS)
        self.treatUknownNumbersAsIndices = (taxNS == TaxonNaming.NUMBERS_ONLY) or \
                                           (taxNS == TaxonNaming.LABELS_OR_NUMBERS)
        self.labelTranslation = {}
        if commandStream is None:
            self.prepareToRead(previousBlocks or [])
        # Always initialised before the base constructor runs.
        self.allowDifferingLeafSets = False
        NexusBlock.__init__(self, beginCmd, commandStream, previousBlocks or [])

    def chooseTaxaManager(self):
        """Hook called right before a tree is read.

        If no translation exists and no (truthy) taxa manager is set, and all
        candidate taxa blocks are the same block, adopt that block's labels
        and make it the taxa manager. Ambiguous candidates are ignored.
        """
        if len(self.labelTranslation) > 0 or (hasattr(self, 'taxa_mgr') and self.taxa_mgr):
            return
        if len(self.potentialTaxaBlocks) > 0:
            firstTaxaBlock = self.potentialTaxaBlocks[0]
            for p in self.potentialTaxaBlocks[1:]:
                if firstTaxaBlock != p:
                    return  # ambiguous: do not use any previous taxa block
            # NOTE(review): labels are added through the current taxa_mgr
            # before it is replaced by firstTaxaBlock -- confirm intended.
            self.addNewLabels(firstTaxaBlock.get_tax_labels())
            self.taxa_mgr = firstTaxaBlock

    def prepareToRead(self, previousBlocks):
        """Reset per-block state; remember previous blocks that expose get_tax_labels()."""
        self.potentialTaxaBlocks = []
        for p in previousBlocks:
            try:
                p.get_tax_labels()
            except AttributeError:
                continue  # this block carries no taxon labels
            self.potentialTaxaBlocks.append(p)
        self.labelTranslation = {}
        self.trees = []

    def getTrees(self):
        """Return the list of trees read so far."""
        return self.trees

    def addNewLabels(self, labels):
        """Append `labels` as new taxa and index them in labelTranslation.

        Raises AssertionError if the taxa are already final.
        """
        if self.taxa_are_final():
            raise AssertionError('Adding taxa to finalized taxa manager.')
        n = len(self.get_tax_labels())
        for i, l in enumerate(labels):
            ind = n + i
            self.labelTranslation[str(ind + 1)] = ind
            self.labelTranslation[l.upper()] = ind
        self.taxa_mgr.add_taxa(labels)

    def addTree(self, t):
        """Flush labels queued by translate_tax_label into the taxa manager, then store `t`."""
        if len(self.__dict__.get('new_labels', [])) > 0:
            self.addNewLabels(self.new_labels)
            del self.new_labels
        self.trees.append(t)

    def translate_tax_label(self, tax_label):
        """Return the 0-based taxon index for `tax_label` (case-insensitive).

        Lookup order: labelTranslation, then the taxa manager. If the label
        is unknown and the taxa are not final, it is queued in `new_labels`
        (indexed after the manager's current labels, or as ``int - 1`` for
        numeric labels when numbers are treated as indices). Otherwise
        NexusUnknownTaxonLabelError is raised.
        """
        capLabel = str(tax_label).upper()
        if hasattr(self, 'labelTranslation'):
            ind = self.labelTranslation.get(capLabel)
            if ind is not None:
                return ind
        try:
            return self.taxa_mgr.translate_tax_label(tax_label)
        except NexusUnknownTaxonLabelError:
            pass
        tl = self.taxa_mgr.get_tax_labels()
        if not self.taxa_mgr.taxa_are_final():
            pos = None
            if 'new_labels' not in self.__dict__:
                self.new_labels = []
            else:
                try:
                    pos = index_of_first_match(self.new_labels, lambda u: capLabel == u.upper())
                except ValueError:
                    pos = None
            if pos is None:
                self.new_labels.append(str(tax_label))
                pos = len(self.new_labels) - 1
            if capLabel.isdigit() and self.treatUknownNumbersAsIndices:
                self.taxa_mgr.validate_tax_label(capLabel, [])
                return int(capLabel) - 1
            # Queued labels are numbered after the manager's existing labels;
            # a repeat lookup now yields the same index as the first one.
            return len(tl) + pos
        raise NexusUnknownTaxonLabelError(tax_label, tl)

    def addTranslateTable(self, translateDict, translateKeyOrder):
        """Validate a Translate table, choose a taxa manager, and build labelTranslation."""
        if 'tax_labels' in self.__dict__:
            NexusTreesBlock.validateTranslateTable(translateDict, self, translateKeyOrder)
        else:
            for t in self.potentialTaxaBlocks:
                try:
                    if NexusTreesBlock.validateTranslateTable(translateDict, t, translateKeyOrder):
                        self.taxa_mgr = t
                        break
                except NexusUnknownTaxonLabelError:
                    pass
        if 'tax_labels' not in self.__dict__:
            self.tax_labels = [translateDict[k] for k in translateKeyOrder]
            if len(set(self.tax_labels)) != len(self.tax_labels):
                self.tax_labels = stable_unique(self.tax_labels)
        self.createLabelTranslationDict(translateDict, translateKeyOrder)

    def createLabelTranslationDict(self, translateDict, translateKeyOrder):
        """Rebuild labelTranslation from the Translate table and the current taxa."""
        self.labelTranslation = {}
        for k in translateKeyOrder:
            v = translateDict[k]
            inds = self.translate_tax_label(v)
            self.labelTranslation[k.upper()] = inds
            self.labelTranslation[v.upper()] = inds
        for i in range(self.get_n_tax()):
            self.labelTranslation[str(i + 1)] = i
            self.labelTranslation[self.get_tax_label(i).upper()] = i

    def validateTranslateTable(translateDict, taxaContext, translateKeyOrder):
        """Return True if every translated label is known to `taxaContext`.

        Propagates whatever taxaContext.translate_tax_label raises.
        """
        for k in translateKeyOrder:
            taxaContext.translate_tax_label(translateDict[k])
        return True
    validateTranslateTable = staticmethod(validateTranslateTable)