Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions prody/tests/datafiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
package that contains test modules and files as well."""


from os.path import join, isfile, split, splitext
from os.path import join, isfile, splitext
from prody.tests import TestCase

from numpy import array
import numpy as np

from prody import parsePDB, parseDCD, parseMMCIF, parseMMTF
from prody import parseSparseMatrix, parseArray, loadModel
from prody.tests import TEMPDIR, TESTDIR
from prody import parseSparseMatrix, parseArray, parseTree, loadModel
from prody.tests import TEMPDIR, TESTDIR # here for others to import


DATA_FILES = {
Expand Down Expand Up @@ -453,6 +453,16 @@
'n_atoms': 4,
'long_resname': 'ACET',
'short_resname': 'ACE'
},
'upgma_tree': {
'file': 'simple_tree_upgma.nwk',
'n_leaves': 4,
'n_top_clades': 2,
},
'nj_tree': {
'file': 'simple_tree_nj.nwk',
'n_leaves': 4,
'n_top_clades': 3,
}
}

Expand All @@ -463,7 +473,8 @@
'.coo': parseSparseMatrix, '.dat': parseArray,
'.txt': np.loadtxt,
'.npy': lambda fn, **kwargs: np.load(fn, allow_pickle=True),
'.gz': lambda fn, **kwargs: PARSERS[splitext(fn)[1]](fn, **kwargs)
'.gz': lambda fn, **kwargs: PARSERS[splitext(fn)[1]](fn, **kwargs),
'.nwk': parseTree
}


Expand Down
1 change: 1 addition & 0 deletions prody/tests/datafiles/simple_tree_nj.nwk
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(A:0.25000,(C:1.00000,B:0.50000):0.50000,D:0.75000):0.00000;
1 change: 1 addition & 0 deletions prody/tests/datafiles/simple_tree_upgma.nwk
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
((C:0.75000,B:0.75000):0.12500,(D:0.50000,A:0.50000):0.37500):0.00000;
227 changes: 227 additions & 0 deletions prody/tests/utilities/test_catchall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""This module contains some unit tests for :mod:`prody.utilities.catchall` module,
starting with tree-related ones."""

import os
import tempfile
import numpy as np

from prody.tests import unittest
from prody.utilities import calcTree, findSubgroups, writeTree, parseTree
from prody.tests.datafiles import parseDatafile, pathDatafile

class TestCalcTree(unittest.TestCase):
    """Checks for :func:`calcTree` with the UPGMA and NJ methods."""

    @staticmethod
    def _fourTaxa():
        """Return the four labels and the 4x4 distance matrix shared by the
        tree-building tests."""
        labels = ['A', 'B', 'C', 'D']
        distances = np.array([[0, 1, 2, 1],
                              [1, 0, 1.5, 2],
                              [2, 1.5, 0, 2],
                              [1, 2, 2, 0]])
        return labels, distances

    def _assertTreeShape(self, method, n_top_clades):
        """Build a tree with *method* and verify its leaves and the number
        of clades directly under the root."""
        labels, distances = self._fourTaxa()
        tree = calcTree(labels, distances, method=method)
        self.assertIsNotNone(tree)
        # All four labels must show up as terminal nodes, and nothing else
        terminals = tree.get_terminals()
        self.assertEqual(len(terminals), 4)
        self.assertEqual({terminal.name for terminal in terminals},
                         set(labels))
        self.assertEqual(len(tree.root.clades), n_top_clades)

    def testCalcTreeUPGMA(self):
        """Test calcTree with UPGMA method."""
        # UPGMA is expected to split the root evenly into two clades
        self._assertTreeShape('upgma', 2)

    def testCalcTreeNJ(self):
        """Test calcTree with NJ method."""
        # NJ is expected to split the root unevenly into three clades
        self._assertTreeShape('nj', 3)

    def testCalcTreeMismatchSize(self):
        """Test calcTree with mismatched names and matrix sizes."""
        two_labels = ['A', 'B']
        three_by_three = np.array([[0, 1, 2],
                                   [1, 0, 1.5],
                                   [2, 1.5, 0]])
        self.assertRaises(ValueError, calcTree, two_labels, three_by_three)


class TestFindSubgroups(unittest.TestCase):
    """Checks for :func:`findSubgroups` with the naive method on a small
    UPGMA tree."""

    def setUp(self):
        """Build a UPGMA tree from a matrix with two obvious clusters."""
        # A and B sit 0.5 apart, as do C and D, while the two pairs are
        # separated from each other by a distance of 5
        near, far = 0.5, 5.0
        self.names = ['A', 'B', 'C', 'D']
        self.distance_matrix = np.array([[0.0, near, far, far],
                                         [near, 0.0, far, far],
                                         [far, far, 0.0, near],
                                         [far, far, near, 0.0]])
        self.tree = calcTree(self.names, self.distance_matrix,
                             method='upgma')

    def testFindSubgroupsNaiveMethod(self):
        """Test findSubgroups with naive method."""
        # A cutoff of 2.0 should yield two subgroups
        groups = findSubgroups(self.tree, 2.0, method='naive')
        self.assertIsNotNone(groups)
        self.assertEqual(len(groups), 2)
        # Taken together, the subgroups must cover exactly the input names
        members = set()
        for group in groups:
            members.update(group)
        self.assertEqual(members, set(self.names))

    def testFindSubgroupsNaiveLargeCutoff(self):
        """Test findSubgroups with naive method and large cutoff."""
        # A cutoff of 10.0 should leave everything in a single subgroup
        groups = findSubgroups(self.tree, 10.0, method='naive')
        self.assertEqual(len(groups), 1)
        only_group = groups[0]
        self.assertEqual(set(only_group), set(self.names))

    def testFindSubgroupsNaiveTinyCutoff(self):
        """Test findSubgroups with naive method and tiny cutoff."""
        # A cutoff of 0.1 should put every name in its own subgroup
        groups = findSubgroups(self.tree, 0.1, method='naive')
        self.assertEqual(len(groups), 4)
        for group in groups:
            self.assertEqual(len(group), 1)

    def testFindSubgroupsReturnsListOfLists(self):
        """Test that findSubgroups returns a list of lists."""
        groups = findSubgroups(self.tree, 2.0, method='naive')
        self.assertIsInstance(groups, list)
        for group in groups:
            self.assertIsInstance(group, list)


class TestParseTree(unittest.TestCase):
    """Tests for :func:`parseTree` using the Newick files in the test data."""

    def setUp(self):
        """Skip every test in this class when Biopython is unavailable.

        ``parseTree`` needs :mod:`Bio.Phylo`; probing for it once here keeps
        the skip behaviour the same for all tests and keeps the guarded
        ``try`` body down to the import alone, so an unrelated
        :exc:`ImportError` raised while parsing is not mistaken for a
        missing Biopython.
        """
        try:
            from Bio import Phylo  # noqa: F401 -- availability probe only
        except ImportError:
            self.skipTest("Biopython not available")

    def _assertParsedShape(self, label, n_leaves, n_top_clades):
        """Parse the datafile registered under *label* and check the number
        of leaves and of clades directly under the root."""
        tree = parseTree(pathDatafile(label))
        self.assertIsNotNone(tree)
        self.assertEqual(len(tree.get_terminals()), n_leaves)
        self.assertEqual(len(tree.root.clades), n_top_clades)

    def testParseUPGMATree(self):
        """Test parsing an UPGMA tree from a file."""
        self._assertParsedShape('upgma_tree', 4, 2)

    def testParseNJTree(self):
        """Test parsing a neighbor-joining tree from a file."""
        self._assertParsedShape('nj_tree', 4, 3)

    def testParseTreeTreeType(self):
        """Test that parseTree returns a Biopython Tree object."""
        # Guaranteed to import: setUp has already skipped otherwise
        from Bio import Phylo
        tree = parseDatafile('upgma_tree')
        self.assertIsInstance(tree, Phylo.BaseTree.Tree)

    def testParseTreeWrongFilepath(self):
        """Test parseTree with non-existent file."""
        with self.assertRaises((AssertionError, FileNotFoundError)):
            parseTree('/nonexistent/path/to/tree.nwk')

    def testParseTreeWrongFileType(self):
        """Test parseTree with invalid filename argument."""
        with self.assertRaises(TypeError):
            parseTree(123)


class TestWriteTree(unittest.TestCase):
    """Tests for :func:`writeTree`, including a write/parse round trip."""

    def setUp(self):
        """Load the reference trees and create a scratch directory.

        The whole class is skipped when Biopython cannot be imported.
        Checking here, rather than inside each test, means a missing
        Biopython results in skipped tests instead of errors raised while
        the tree fixtures are being loaded.
        """
        try:
            from Bio import Phylo  # noqa: F401 -- availability probe only
        except ImportError:
            self.skipTest("Biopython not available")

        self.upgma_tree = parseDatafile('upgma_tree')
        self.nj_tree = parseDatafile('nj_tree')
        # Created last so that a failure above leaves nothing to clean up
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory and everything written into it."""
        import shutil
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def _assertWritesNewick(self, tree, basename):
        """Write *tree* to *basename* inside the scratch directory and check
        that the result is a non-empty file containing a ``;``."""
        output_file = os.path.join(self.temp_dir, basename)
        writeTree(output_file, tree)
        self.assertTrue(os.path.exists(output_file))
        with open(output_file, 'r') as f:
            content = f.read()
        self.assertGreater(len(content), 0)
        # Newick strings are terminated by a semicolon
        self.assertIn(';', content)

    def testWriteUPGMATree(self):
        """Test writing an UPGMA tree to a file."""
        self._assertWritesNewick(self.upgma_tree, 'test_upgma.nwk')

    def testWriteNJTree(self):
        """Test writing a neighbor-joining tree to a file."""
        self._assertWritesNewick(self.nj_tree, 'test_nj.nwk')

    def testWriteTreeWrongFilename(self):
        """Test writeTree with invalid filename argument."""
        with self.assertRaises(TypeError):
            writeTree(123, self.upgma_tree)

    def testWriteTreeWrongTreeType(self):
        """Test writeTree with invalid tree argument."""
        output_file = os.path.join(self.temp_dir, 'test.nwk')
        with self.assertRaises(TypeError):
            writeTree(output_file, "not a tree")

    def testWriteTreeWrongFormat(self):
        """Test writeTree with invalid format argument."""
        output_file = os.path.join(self.temp_dir, 'test.nwk')
        with self.assertRaises(TypeError):
            writeTree(output_file, self.upgma_tree, format_str=123)

    def testWriteAndParseRoundtrip(self):
        """Test writing a tree and then parsing it back."""
        output_file = os.path.join(self.temp_dir, 'roundtrip.nwk')
        writeTree(output_file, self.upgma_tree)
        parsed_tree = parseTree(output_file)
        # The tree read back must have the same shape as the one written
        self.assertIsNotNone(parsed_tree)
        self.assertEqual(len(parsed_tree.get_terminals()), 4)
        self.assertEqual(len(parsed_tree.root.clades), 2)

if __name__ == '__main__':
unittest.main()
10 changes: 7 additions & 3 deletions prody/utilities/TreeConstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
import itertools
import copy
import numbers
from Bio.Phylo import BaseTree
from Bio.Phylo.TreeConstruction import _DistanceMatrix as DistanceMatrix

__all__ = ['_Matrix', 'DistanceMatrix', 'TreeConstructor', 'DistanceTreeConstructor']
__all__ = ['_Matrix', 'TreeConstructor', 'DistanceTreeConstructor']

class _Matrix:
"""Base class for distance matrix or scoring matrix.
Expand Down Expand Up @@ -400,6 +398,9 @@ def upgma(self, distance_matrix):
The distance matrix for tree construction.

"""
from Bio.Phylo import BaseTree
from Bio.Phylo.TreeConstruction import _DistanceMatrix as DistanceMatrix

if not isinstance(distance_matrix, DistanceMatrix):
raise TypeError("Must provide a DistanceMatrix object.")

Expand Down Expand Up @@ -456,6 +457,9 @@ def nj(self, distance_matrix):
The distance matrix for tree construction.

"""
from Bio.Phylo import BaseTree
from Bio.Phylo.TreeConstruction import _DistanceMatrix as DistanceMatrix

if not isinstance(distance_matrix, DistanceMatrix):
raise TypeError("Must provide a DistanceMatrix object.")

Expand Down
33 changes: 29 additions & 4 deletions prody/utilities/catchall.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from .logger import LOGGER


__all__ = ['calcTree', 'clusterMatrix',
__all__ = ['calcTree', 'writeTree', 'parseTree',
'clusterMatrix',
'showLines', 'showMatrix', 'showBars',
'reorderMatrix', 'findSubgroups', 'getCoords',
'getLinkage', 'getTreeFromLinkage', 'clusterSubfamilies',
Expand Down Expand Up @@ -239,7 +240,7 @@ def getTreeFromLinkage(names, linkage):
:arg linkage: linkage matrix
:type linkage: :class:`~numpy.ndarray`
"""
try:
try:
from Bio.Phylo.BaseTree import Tree, Clade
except ImportError:
raise ImportError('Phylo module could not be imported. '
Expand Down Expand Up @@ -308,7 +309,8 @@ def calcTree(names, distance_matrix, method='upgma', linkage=False):
:type linkage: bool
"""

from .TreeConstruction import DistanceMatrix, DistanceTreeConstructor
from .TreeConstruction import DistanceTreeConstructor
from Bio.Phylo.TreeConstruction import _DistanceMatrix as DistanceMatrix

if len(names) != distance_matrix.shape[0] or len(names) != distance_matrix.shape[1]:
raise ValueError("Mismatch between the sizes of matrix and names.")
Expand Down Expand Up @@ -366,7 +368,7 @@ def writeTree(filename, tree, format_str='newick'):
:arg format_str: a string specifying the format for the tree
:type format_str: str
"""
try:
try:
from Bio import Phylo
except ImportError:
raise ImportError('Phylo module could not be imported. '
Expand All @@ -384,6 +386,29 @@ def writeTree(filename, tree, format_str='newick'):

Phylo.write(tree, filename, format_str)

def parseTree(filename, format_str='newick'):
    """Parse a tree from a file using Biopython.

    :arg filename: name of the input file to read the tree from
    :type filename: str

    :arg format_str: a string specifying the format of the tree file,
        default is ``'newick'``
    :type format_str: str

    :returns: the tree object returned by :func:`Bio.Phylo.read`

    :raises TypeError: if *filename* or *format_str* is not a string
    :raises ImportError: if Biopython's Phylo module cannot be imported
    """
    # Validate the arguments before touching the optional dependency, so
    # bad input is reported as TypeError whether or not Biopython is there.
    if not isinstance(filename, str):
        raise TypeError('filename should be a string')

    if not isinstance(format_str, str):
        raise TypeError('format_str should be a string')

    try:
        from Bio import Phylo
    except ImportError:
        raise ImportError('Phylo module could not be imported. '
                          'Reinstall ProDy or install Biopython '
                          'to solve the problem.')

    return Phylo.read(filename, format_str)

def clusterMatrix(distance_matrix=None, similarity_matrix=None, labels=None, return_linkage=None, **kwargs):
"""
Expand Down
Loading