diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index e13b814072688..9c3d806afa8fb 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -209,17 +209,26 @@ def __init__(self, ar):
         self.array = ar
 
     @staticmethod
-    def parse(vectorString):
+    def parse(s):
         """
         Parse string representation back into the DenseVector.
 
         >>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
         DenseVector([0.0, 1.0, 2.0, 3.0])
         """
-        start = vectorString.find('[')
-        end = vectorString.find(']')
-        vectorString = vectorString[start + 1: end]
-        return DenseVector([float(val) for val in vectorString.split(',')])
+        start = s.find('[')
+        if start == -1:
+            raise ValueError("Array should start with '['")
+        end = s.find(']')
+        if end == -1:
+            raise ValueError("Array should end with ']'")
+        s = s[start + 1: end]
+
+        try:
+            values = [float(val) for val in s.split(',')] if s.strip() else []
+        except ValueError:
+            raise ValueError("Unable to parse values from %s." % s)
+        return DenseVector(values)
 
     def __reduce__(self):
         return DenseVector, (self.array.tostring(),)
@@ -436,28 +445,51 @@ def __reduce__(self):
             (self.size, self.indices.tostring(), self.values.tostring()))
 
     @staticmethod
-    def parse(vectorString):
+    def parse(s):
         """
-        Parse string representation back into the DenseVector.
+        Parse string representation back into the SparseVector.
 
         >>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
         SparseVector(4, {0: 4.0, 1: 5.0})
         """
-        start = vectorString.find('(')
-        end = vectorString.find(')')
-        vectorString = vectorString[start+1: end].strip()
-        size = int(vectorString[0])
-
-        ind_start = vectorString.find('[')
-        ind_end = vectorString.find(']')
-        ind_list = vectorString[ind_start + 1: ind_end].split(',')
-        indices = [int(ind) for ind in ind_list]
-        vectorString = vectorString[ind_end + 1:].strip()
-
-        val_start = vectorString.find('[')
-        val_end = vectorString.find(']')
-        val_list = vectorString[val_start + 1: val_end].split(',')
-        values = [float(val) for val in val_list]
+        start = s.find('(')
+        if start == -1:
+            raise ValueError("Tuple should start with '('")
+        end = s.find(')')
+        if end == -1:
+            raise ValueError("Tuple should end with ')'")
+        s = s[start + 1: end].strip()
+
+        size = s[: s.find(',')]
+        try:
+            size = int(size)
+        except ValueError:
+            raise ValueError("Cannot parse size %s." % size)
+
+        ind_start = s.find('[')
+        if ind_start == -1:
+            raise ValueError("Indices array should start with '['.")
+        ind_end = s.find(']')
+        if ind_end == -1:
+            raise ValueError("Indices array should end with ']'.")
+        ind_str = s[ind_start + 1: ind_end].strip()
+        try:
+            indices = [int(ind) for ind in ind_str.split(',')] if ind_str else []
+        except ValueError:
+            raise ValueError("Unable to parse indices from %s." % ind_str)
+        s = s[ind_end + 1:].strip()
+
+        val_start = s.find('[')
+        if val_start == -1:
+            raise ValueError("Values array should start with '['.")
+        val_end = s.find(']')
+        if val_end == -1:
+            raise ValueError("Values array should end with ']'.")
+        val_str = s[val_start + 1: val_end].strip()
+        try:
+            values = [float(val) for val in val_str.split(',')] if val_str else []
+        except ValueError:
+            raise ValueError("Unable to parse values from %s." % val_str)
         return SparseVector(size, indices, values)
 
     def dot(self, other):
@@ -704,7 +736,7 @@ def stringify(vector):
         return str(vector)
 
     @staticmethod
-    def squared_distance(a, b):
+    def squared_distance(v1, v2):
         """
         Squared distance between two vectors.
-        a and b can be of type SparseVector, DenseVector, np.ndarray
+        v1 and v2 can be of type SparseVector, DenseVector, np.ndarray
@@ -715,25 +747,36 @@ def squared_distance(a, b):
         >>> a.squared_distance(b)
         51.0
         """
-        a, b = _convert_to_vector(a), _convert_to_vector(b)
-        return a.squared_distance(b)
+        v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2)
+        return v1.squared_distance(v2)
 
     @staticmethod
-    def norm(vec, p):
+    def norm(vector, p):
         """
         Find norm of the given vector.
         """
-        return _convert_to_vector(vec).norm(p)
+        return _convert_to_vector(vector).norm(p)
 
     @staticmethod
-    def parse(vectorString):
-        if vectorString.find('(') == -1:
-            return DenseVector.parse(vectorString)
-        return SparseVector.parse(vectorString)
+    def parse(s):
+        """Parse a string representation back into the Vector.
+
+        >>> Vectors.parse('[2,1,2 ]')
+        DenseVector([2.0, 1.0, 2.0])
+        >>> Vectors.parse(' ( 100, [0], [2])')
+        SparseVector(100, {0: 2.0})
+        """
+        if s.find('(') == -1 and s.find('[') != -1:
+            return DenseVector.parse(s)
+        elif s.find('(') != -1:
+            return SparseVector.parse(s)
+        else:
+            raise ValueError(
+                "Cannot find tokens '[' or '(' from the input string.")
 
     @staticmethod
-    def zeros(num):
-        return DenseVector(np.zeros(num))
+    def zeros(size):
+        return DenseVector(np.zeros(size))
 
 
 class Matrix(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 2dfd62676ab2e..36a4c7a5408c6 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -227,6 +227,8 @@ def test_parse_vector(self):
         a = SparseVector(4, [0, 2], [3, 4])
-        self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
-        self.assertTrue(Vectors.parse(str(a)), a)
+        self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
+        self.assertEqual(Vectors.parse(str(a)), a)
+        a = SparseVector(10, [0, 1], [4, 5])
+        self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
 
     def test_norms(self):
         a = DenseVector([0, 2, 3, -1])