diff --git a/cluster/util.py b/cluster/util.py
index 461d1ac..6b75299 100644
--- a/cluster/util.py
+++ b/cluster/util.py
@@ -168,9 +168,38 @@ def dotproduct(a, b):
     return out
 
 
-def centroid(data, method=median):
-    "returns the central vector of a list of vectors"
+def centroid(data, method=median, getter=None, cumulator=None):
+    """
+    returns the central vector of a list of vectors.
+
+    :param data: The container of values.
+    :param method: An accumulation function (usually ``median`` or ``mean``)
+    :param getter: A function to access the data field of custom objects.
+        The extracted field may be a single number or a vector of numbers.
+    :param cumulator: Reserved; not used by this implementation yet. It
+        must not be combined with ``getter``.
+    :return: ``None`` for empty ``data``, a single value if ``getter``
+        yields plain numbers, otherwise a tuple with one value per dimension.
+    :raises ValueError: if both ``getter`` and ``cumulator`` are given.
+    """
     out = []
+    if not data:
+        return None
+
+    if getter and cumulator:
+        raise ValueError('You should only supply either a cumulator or getter '
+                         'function! Not both!')
+
+    if getter:
+        # Local import: the module's top-level imports are outside this hunk.
+        from numbers import Number
+        # Extract once and reuse the result for both code paths below.
+        extracted_data = [getter(_) for _ in data]
+        if isinstance(extracted_data[0], Number):
+            # Scalar fields (int, float, ...) are accumulated directly.
+            return method(extracted_data)
+        return tuple(method(x) for x in zip(*extracted_data))
+
     for i in range(len(data[0])):
         out.append(method([x[i] for x in data]))
     return tuple(out)
diff --git a/test.py b/test.py
index af096eb..ae3e27f 100644
--- a/test.py
+++ b/test.py
@@ -15,15 +15,20 @@
 # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 #
 
-from cluster import (HierarchicalClustering, KMeansClustering, ClusteringError)
+from base64 import b64encode
 from difflib import SequenceMatcher
+from os import urandom
+from random import randint
 import unittest
+
 try:
     import numpy
     NUMPY_AVAILABLE = True
 except:
     NUMPY_AVAILABLE = False
 
+from cluster import (HierarchicalClustering, KMeansClustering, ClusteringError)
+import cluster.util as util
 
 def compare_list(x, y):
     """
@@ -46,6 +51,19 @@ def compare_list(x, y):
     return all_ok
 
 
+class MyObject(object):
+    """
+    A custom data object used in some tests.
+    """
+    def __init__(self, value, uid=None):
+        self.value = value
+        # b64encode() is portable; bytes.encode('base64') fails on Python 3.
+        self.uid = uid or b64encode(urandom(10)).decode('ascii')
+
+    def __repr__(self):
+        return 'MyObject({!r}, {!r})'.format(self.value, self.uid)
+
+
 class HClusterSmallListTestCase(unittest.TestCase):
     """
     Test for Bug #1516204
@@ -217,6 +235,58 @@ def testMultidimArray(self):
         cl.getclusters(10)
 
 
+class KClusterGithubIssues(unittest.TestCase):
+
+    def test_custom_object_data(self):
+        self.skipTest("temporarily skipped until centroid is solved")
+        data = [MyObject(randint(0, 1000)) for _ in range(40)]
+        cl = KMeansClustering(data, lambda x, y: float(abs(x.value-y.value)))
+        clustered = cl.getclusters(10)
+        print(clustered)
+        self.fail()
+
+
+class TestUtils(unittest.TestCase):
+
+    def test_default_centroid(self):
+        result = util.centroid([
+            (1, 2, 3),
+            (2, 3, 4),
+            (3, 4, 5),
+            (4, 5, 6),
+        ])
+        self.assertEqual(result, (2.5, 3.5, 4.5))
+
+    def test_custom_centroid_fold(self):
+        self.skipTest("temporarily skipped until centroid supports cumulator")
+        result = util.centroid([
+            MyObject((1, 2, 3), 1),
+            MyObject((2, 3, 4), 2),
+            MyObject((3, 4, 5), 3),
+            MyObject((4, 5, 6), 4),
+        ])
+        self.assertEqual(result, (2.5, 3.5, 4.5))
+
+    def test_custom_centroid_getter(self):
+        result = util.centroid([
+            MyObject((1, 2, 3), 1),
+            MyObject((2, 3, 4), 2),
+            MyObject((3, 4, 5), 3),
+            MyObject((4, 5, 6), 4),
+        ], getter=lambda x: x.value)
+        self.assertEqual(result, (2.5, 3.5, 4.5))
+
+    def test_custom2_centroid_getter(self):
+        result = util.centroid([
+            MyObject(1, 1),
+            MyObject(2, 2),
+            MyObject(3, 3),
+            MyObject(4, 4),
+        ], getter=lambda x: x.value)
+        self.assertEqual(result, 2.5)
+
+
+
 @unittest.skipUnless(NUMPY_AVAILABLE, 'numpy not available. Associated test will not be loaded!')
 class NumpyTests(unittest.TestCase):
 