#!/usr/local/bin/python
from PIPRes.util.io import cipresGetLogger
_LOG = cipresGetLogger('modelSpec')


def iterCells(nRows, nCols=None, starting=(0, 0)):
    '''Yield [row, col] index pairs of an nRows x nCols grid in row-major order.

    nRows    -- number of rows.
    nCols    -- number of columns; defaults to nRows (a square grid).
    starting -- indexable (row, col) of the first cell to yield.  A column
                index at or past nCols is treated as the first cell of the
                following row.

    Each yielded cell is a new two-element list.  Nothing is yielded when the
    grid has no rows or no columns, or when the starting cell lies past the
    last cell.
    '''
    if nCols is None:
        nCols = nRows
    # An empty grid has no cells; without this, nCols == 0 yielded bogus cells.
    if nRows <= 0 or nCols <= 0:
        return
    currRow = starting[0]
    currCol = starting[1]
    # A start column past the end means that row is exhausted.  Continue at the
    # start of the next row so no out-of-range cell is ever yielded.  (The old
    # guard compared against nRows + 1 and could never fire.)
    if currCol >= nCols:
        currRow += 1
        currCol = 0
    while currRow < nRows:
        yield [currRow, currCol]
        currCol += 1
        if currCol >= nCols:
            currCol = 0
            currRow += 1


# NOTE(review): everything below is a disabled design sketch.  It uses
# ParameterHolder, which is not defined or imported in this module, and
# contains unfinished code, so it is kept behind ``if False`` and never runs.
if False:
    class ParameterUser(object):
        '''Base for objects that may be attached to a parameter context.'''

        def __init__(self, paramContext=None):
            self._paramContext = paramContext

        def getParameters(self):
            # Parameters come from the context; [] when there is no context.
            return self._paramContext is not None and self._paramContext.getParams(self) or []

        def __str__(self):
            return self.__class__.__name__

    class ContinuousDistribution(ParameterUser):
        def getLnDensity(self, x):
            # Stub: always returns -1.0.
            dens = -1.0
            _LOG.debug('%s getLnDensity(%f) = %f' % (str(self), x, dens))
            return dens

    class JointDistribution(ParameterUser):
        def getLnDensity(self, x):
            # Stub: always returns -1.0.
            dens = -1.0
            _LOG.debug('%s getLnDensity(%s) = %f' % (str(self), str(x), dens))
            return dens

    class Parameter(ParameterUser):
        def __init__(self, v, **kwargs):
            # NOTE(review): ``v`` is unused and ParameterUser.__init__ is not called.
            self.distribution = kwargs.get('distribution')
            self.value = kwargs.get('value')  # goes through the ``value`` property
            self.group = kwargs.get('group')

        def getValue(self):
            return self._value

        def setValue(self, x):
            self._value = x

        value = property(getValue, setValue)

        def getDistribution(self):
            return self.distribution

        def getGroup(self):
            # Previously returned None, ignoring the group stored in __init__
            # (and assigned by ParmeterGroup).
            return self.group

    class ParmeterGroup(ParameterUser):
        def __init__(self, params=None, **kwargs):
            # None default so instances do not share one list object.
            if params is None:
                params = []
            self.params = params
            for i in self.params:
                i.group = self
            self.distribution = kwargs.get('distribution')

        def getJointDistribution(self):
            return self.distribution

        def getParameters(self):
            return self.params

    class Substitution(object):
        '''A from-state -> to-state cell of a substitution matrix.'''

        def __init__(self, fromState, toState, datatype=None):
            self.cell = [fromState, toState]
            # Was a bare ``self.datatype`` expression (AttributeError).
            self.datatype = datatype

        def __str__(self):
            if self.datatype is not None:
                sym = self.datatype.symbols
                return '%c->%c' % (sym[self.cell[0]], sym[self.cell[1]])
            return '%d->%d' % (self.cell[0], self.cell[1])

    class Datatype(object):
        def __init__(self, symbols):
            self.symbols = symbols

    class SubstitutionClassification(ParameterHolder):
        def __init__(self, substitMat, datatype=None, paramContext=None):
            super(SubstitutionClassification, self).__init__(paramContext)
            self.datatype = datatype
            self.substMat = substitMat
            for i in self.substMat:
                i.datatype = datatype
            self._paramContext.addParamSlots(self, self.nParams)

        def getNParams(self):
            raise NotImplementedError

        nParams = property(getNParams)

    def _scListFromList(p):
        '''Takes a matrix of objects and returns a substitution class matrix,
        assuming classes are distinguished by object identity.

        Returns an (scList, condensed) pair.  Unfinished: only the 0-row and
        1-row cases return; every other path raises.
        '''
        nRows = len(p)
        if nRows == 0:
            return [[]], p
        if nRows == 1:
            return [[Substitution(0, 1), Substitution(1, 0)]], p
        nCurrCol = len(p[0])
        nSecCol = len(p[1])
        if nSecCol == nCurrCol:
            skipDiag = nCurrCol == nRows
            assert(skipDiag or nCurrCol == nRows - 1)
            # TODO: equal-row-length case not implemented; falls through to the raise.
        elif nSecCol == nCurrCol - 1:
            # TODO: was a bare placeholder name ``x`` (NameError).
            raise NotImplementedError('row lengths decreasing by one are not handled yet')
        elif nSecCol == nCurrCol + 1:
            assert(nCurrCol == 1)
        raise ValueError('Row lengths do not conform to upper triangle, lower triangle, or full matrix')

    # Bare minimum needed to calculate a transition probability.
    class SiteModelPrimitive(ParameterHolder):
        def __init__(self, **kwargs):
            # super() previously named SubstitutionClassification, which is not a base.
            super(SiteModelPrimitive, self).__init__(kwargs.get('paramContext'))
            sc = kwargs.get('substClass')
            self.substClass = None
            if sc is not None:
                self._paramContext.addLevel(self, children=sc)
                self.substClass = sc
            p = kwargs.get('parameters')
            if p is not None:
                scList, condensed = _scListFromList(p)
                if self.substClass is None:
                    # NOTE(review): self.datatype is never assigned by this class.
                    self.substClass = SubstitutionClassification(scList, self.datatype, self._paramContext)
                else:
                    assert(len(condensed) == self.substClass.nParams)
                assert(self._paramContext)
                # NOTE(review): ``loadParmeters`` looks misspelled; confirm against the context API.
                self._paramContext.loadParmeters(self.substClass, condensed)

        def isReversible(self):
            return False

        def getRateMatrix(self):
            raise NotImplementedError

    class ReversibleSiteModelPrimitive(SiteModelPrimitive):
        def __init__(self, substClass, stateFreq, paramContext):
            self.stateFreq = stateFreq
            # SiteModelPrimitive.__init__ accepts keyword arguments only.
            super(ReversibleSiteModelPrimitive, self).__init__(substClass=substClass, paramContext=paramContext)
            # NOTE(review): nParams is not defined by SiteModelPrimitive.
            self._paramContext.addParamSlots(self, self.nParams)

        def isReversible(self):
            return True

    class NonReversibleSiteModelPrimitive(SiteModelPrimitive):
        # SubstitutionClassification rmat;
        pass

    class StandardizedSiteModel:
        # SiteModelPrimitive modelShape;
        # Param relativeRate;
        pass

    class MixtureModel(StandardizedSiteModel):
        # ParamGroup getMixtureProportionGroup;
        # ModelSelector getModelSelector()
        # StandardizedSiteModel [] components;
        pass

    class TreeRegion:
        # enum RegionDivision {kWholeDamnTree, kByEdges}
        # sequence edges;
        # sequence
        pass

    class ModelOnTree:
        # StandardizedSiteModel model
        # MatrixRange chars
        # TreeRegion partOfTree
        pass

    class SuperMatrixModel:
        # ModelOnTree [] subsetModels;
        pass