Skip to content

Commit

Permalink
Add error message for parser
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed May 7, 2015
1 parent 1bd3c04 commit ce3e53e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 32 deletions.
107 changes: 75 additions & 32 deletions python/pyspark/mllib/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,26 @@ def __init__(self, ar):
self.array = ar

@staticmethod
def parse(vectorString):
def parse(s):
"""
Parse string representation back into the DenseVector.
>>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
DenseVector([0.0, 1.0, 2.0, 3.0])
"""
start = vectorString.find('[')
end = vectorString.find(']')
vectorString = vectorString[start + 1: end]
return DenseVector([float(val) for val in vectorString.split(',')])
start = s.find('[')
if start == -1:
raise ValueError("Array should start with '['")
end = s.find(']')
if end == -1:
raise ValueError("Array should end with ']")
s = s[start + 1: end]

try:
values = [float(val) for val in s.split(',')]
except ValueError:
raise ValueError("Unable to parse values.")
return DenseVector(values)

def __reduce__(self):
return DenseVector, (self.array.tostring(),)
Expand Down Expand Up @@ -436,28 +445,51 @@ def __reduce__(self):
(self.size, self.indices.tostring(), self.values.tostring()))

@staticmethod
def parse(vectorString):
def parse(s):
"""
Parse string representation back into the DenseVector.
>>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
SparseVector(4, {0: 4.0, 1: 5.0})
"""
start = vectorString.find('(')
end = vectorString.find(')')
vectorString = vectorString[start+1: end].strip()
size = int(vectorString[0])

ind_start = vectorString.find('[')
ind_end = vectorString.find(']')
ind_list = vectorString[ind_start + 1: ind_end].split(',')
indices = [int(ind) for ind in ind_list]
vectorString = vectorString[ind_end + 1:].strip()

val_start = vectorString.find('[')
val_end = vectorString.find(']')
val_list = vectorString[val_start + 1: val_end].split(',')
values = [float(val) for val in val_list]
start = s.find('(')
if start == -1:
raise ValueError("Tuple should start with '('")
end = s.find(')')
if start == -1:
raise ValueError("Tuple should end with ')'")
s = s[start + 1: end].strip()

size = s[: s.find(',')]
try:
size = int(size)
except ValueError:
raise ValueError("Cannot parse size %s." % size)

ind_start = s.find('[')
if ind_start == -1:
raise ValueError("Indices array should start with '('.")
ind_end = s.find(']')
if ind_end == -1:
raise ValueError("Indices array should end with ')'")
ind_list = s[ind_start + 1: ind_end].split(',')
try:
indices = [int(ind) for ind in ind_list]
except ValueError:
raise ValueError("Unabel to parse indices.")
s = s[ind_end + 1:].strip()

val_start = s.find('[')
if val_start == -1:
raise ValueError("Values array should start with '('.")
val_end = s.find(']')
if val_end == -1:
raise ValueError("Values array should end with ')'.")
val_list = s[val_start + 1: val_end].split(',')
try:
values = [float(val) for val in val_list]
except ValueError:
raise ValueError("Unable to parse values.")
return SparseVector(size, indices, values)

def dot(self, other):
Expand Down Expand Up @@ -704,7 +736,7 @@ def stringify(vector):
return str(vector)

@staticmethod
def squared_distance(a, b):
def squared_distance(v1, v2):
"""
Squared distance between two vectors.
a and b can be of type SparseVector, DenseVector, np.ndarray
Expand All @@ -715,25 +747,36 @@ def squared_distance(a, b):
>>> a.squared_distance(b)
51.0
"""
a, b = _convert_to_vector(a), _convert_to_vector(b)
return a.squared_distance(b)
v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2)
return v1.squared_distance(v2)

@staticmethod
def norm(vec, p):
def norm(vector, p):
"""
Find norm of the given vector.
"""
return _convert_to_vector(vec).norm(p)
return _convert_to_vector(vector).norm(p)

@staticmethod
def parse(vectorString):
if vectorString.find('(') == -1:
return DenseVector.parse(vectorString)
return SparseVector.parse(vectorString)
def parse(s):
"""Parse a string representation back into the Vector.
>>> Vectors.parse('[2,1,2 ]')
DenseVector([2.0, 1.0, 2.0])
>>> Vectors.parse(' ( 100, [0], [2])')
SparseVector(100, {0: 2.0})
"""
if s.find('(') == -1 and s.find('[') != -1:
return DenseVector.parse(s)
elif s.find('(') != -1:
return SparseVector.parse(s)
else:
raise ValueError(
"Cannot find tokens '[' or '(' from the input string.")

@staticmethod
def zeros(num):
return DenseVector(np.zeros(num))
def zeros(size):
return DenseVector(np.zeros(size))


class Matrix(object):
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def test_parse_vector(self):
a = SparseVector(4, [0, 2], [3, 4])
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
self.assertTrue(Vectors.parse(str(a)), a)
a = SparseVector(10, [0, 1], [4, 5])
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)

def test_norms(self):
a = DenseVector([0, 2, 3, -1])
Expand Down

0 comments on commit ce3e53e

Please sign in to comment.