From 5d16d5bbfd242c16ee0d6952c48dcd90651f8ae2 Mon Sep 17 00:00:00 2001
From: Nicholas Chammas
Date: Mon, 21 Jul 2014 22:30:53 -0700
Subject: [PATCH] [SPARK-2470] PEP8 fixes to PySpark

This pull request aims to resolve all outstanding PEP8 violations in PySpark.

Author: Nicholas Chammas
Author: nchammas

Closes #1505 from nchammas/master and squashes the following commits:

98171af [Nicholas Chammas] [SPARK-2470] revert PEP 8 fixes to cloudpickle
cba7768 [Nicholas Chammas] [SPARK-2470] wrap expression list in parentheses
e178dbe [Nicholas Chammas] [SPARK-2470] style - change position of line break
9127d2b [Nicholas Chammas] [SPARK-2470] wrap expression lists in parentheses
22132a4 [Nicholas Chammas] [SPARK-2470] wrap conditionals in parentheses
24639bc [Nicholas Chammas] [SPARK-2470] fix whitespace for doctest
7d557b7 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to tests.py
8f8e4c0 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to storagelevel.py
b3b96cf [Nicholas Chammas] [SPARK-2470] PEP8 fixes to statcounter.py
d644477 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to worker.py
aa3a7b6 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to sql.py
1916859 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to shell.py
95d1d95 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to serializers.py
a0fec2e [Nicholas Chammas] [SPARK-2470] PEP8 fixes to mllib
c85e1e5 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to join.py
d14f2f1 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to __init__.py
81fcb20 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to resultiterable.py
1bde265 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to java_gateway.py
7fc849c [Nicholas Chammas] [SPARK-2470] PEP8 fixes to daemon.py
ca2d28b [Nicholas Chammas] [SPARK-2470] PEP8 fixes to context.py
f4e0039 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to conf.py
a6d5e4b [Nicholas Chammas] [SPARK-2470] PEP8 fixes to cloudpickle.py
f0a7ebf [Nicholas Chammas] [SPARK-2470] PEP8 fixes to rddsampler.py
4dd148f [nchammas] Merge pull request #5 from apache/master
f7e4581 [Nicholas Chammas] unrelated pep8 fix
a36eed0 [Nicholas Chammas] name ec2 instances and security groups consistently
de7292a [nchammas] Merge pull request #4 from apache/master
2e4fe00 [nchammas] Merge pull request #3 from apache/master
89fde08 [nchammas] Merge pull request #2 from apache/master
69f6e22 [Nicholas Chammas] PEP8 fixes
2627247 [Nicholas Chammas] broke up lines before they hit 100 chars
6544b7e [Nicholas Chammas] [SPARK-2065] give launched instances names
69da6cf [nchammas] Merge pull request #1 from apache/master
---
 python/pyspark/__init__.py | 3 ++-
 python/pyspark/conf.py | 9 ++++---
 python/pyspark/context.py | 45 ++++++++++++++++++--------------
 python/pyspark/daemon.py | 12 ++++-----
 python/pyspark/java_gateway.py | 1 +
 python/pyspark/join.py | 4 ++-
 python/pyspark/mllib/_common.py | 4 ++-
 python/pyspark/mllib/linalg.py | 1 +
 python/pyspark/mllib/util.py | 2 --
 python/pyspark/rddsampler.py | 24 ++++++++---------
 python/pyspark/resultiterable.py | 3 +++
 python/pyspark/serializers.py | 31 +++++++++++++---------
 python/pyspark/shell.py | 3 ++-
 python/pyspark/sql.py | 38 +++++++++++++++------------
 python/pyspark/statcounter.py | 25 +++++++++---------
 python/pyspark/storagelevel.py | 5 ++--
 python/pyspark/tests.py | 10 ++++---
 python/pyspark/worker.py | 4 +--
 18 files changed, 127 insertions(+), 97 deletions(-)

diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 07df8697bd1a8..312c75d112cbf 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -59,4 +59,5 @@ from
pyspark.storagelevel import StorageLevel -__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", "SparkFiles", "StorageLevel", "Row"] +__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", + "SparkFiles", "StorageLevel", "Row"] diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 60fc6ba7c52c2..b50590ab3b444 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -50,7 +50,8 @@ spark.executorEnv.VAR4=value4 spark.home=/path >>> sorted(conf.getAll(), key=lambda p: p[0]) -[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] +[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ +(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] """ @@ -118,9 +119,9 @@ def setExecutorEnv(self, key=None, value=None, pairs=None): """Set an environment variable to be passed to executors.""" if (key is not None and pairs is not None) or (key is None and pairs is None): raise Exception("Either pass one key-value pair or a list of pairs") - elif key != None: + elif key is not None: self._jconf.setExecutorEnv(key, value) - elif pairs != None: + elif pairs is not None: for (k, v) in pairs: self._jconf.setExecutorEnv(k, v) return self @@ -137,7 +138,7 @@ def setAll(self, pairs): def get(self, key, defaultValue=None): """Get the configured value for some key, or return a default otherwise.""" - if defaultValue == None: # Py4J doesn't call the right get() if we pass None + if defaultValue is None: # Py4J doesn't call the right get() if we pass None if not self._jconf.contains(key): return None return self._jconf.get(key) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 95c54e7a5ad63..e21be0e10a3f7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer + PairDeserializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -50,12 +50,11 @@ class SparkContext(object): _next_accum_id = 0 _active_spark_context = None _lock = Lock() - _python_includes = None # zip and egg files that need to be added to PYTHONPATH - + _python_includes = None # zip and egg files that need to be added to PYTHONPATH def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, - gateway=None): + environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + gateway=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. 
@@ -138,8 +137,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, self._accumulatorServer = accumulators._start_update_server() (host, port) = self._accumulatorServer.server_address self._javaAccumulator = self._jsc.accumulator( - self._jvm.java.util.ArrayList(), - self._jvm.PythonAccumulatorParam(host, port)) + self._jvm.java.util.ArrayList(), + self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') @@ -165,7 +164,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, (dirname, filename) = os.path.split(path) self._python_includes.append(filename) sys.path.append(path) - if not dirname in sys.path: + if dirname not in sys.path: sys.path.append(dirname) # Create a temporary directory inside spark.local.dir: @@ -192,15 +191,19 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: - if SparkContext._active_spark_context and SparkContext._active_spark_context != instance: + if (SparkContext._active_spark_context and + SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite # Raise error if there is already a running Spark context - raise ValueError("Cannot run multiple SparkContexts at once; existing SparkContext(app=%s, master=%s)" \ - " created by %s at %s:%s " \ - % (currentAppName, currentMaster, callsite.function, callsite.file, callsite.linenum)) + raise ValueError( + "Cannot run multiple SparkContexts at once; " + "existing SparkContext(app=%s, master=%s)" + " created by %s at %s:%s " + % (currentAppName, currentMaster, + callsite.function, callsite.file, callsite.linenum)) else: SparkContext._active_spark_context = instance @@ -290,7 +293,7 @@ def textFile(self, name, minPartitions=None): Read a text file from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI, and return it as an RDD of Strings. - + >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: ... testFile.write("Hello world!") @@ -584,11 +587,12 @@ def addPyFile(self, path): HTTP, HTTPS or FTP URI. """ self.addFile(path) - (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix + (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): self._python_includes.append(filename) - sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode + # for tests in local mode + sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) def setCheckpointDir(self, dirName): """ @@ -649,9 +653,9 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): Cancelled If interruptOnCancel is set to true for the job group, then job cancellation will result - in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure - that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208, - where HDFS may respond to Thread.interrupt() by marking nodes as dead. + in Thread.interrupt() being called on the job's executor threads. 
This is useful to help + ensure that the tasks are actually stopped in a timely manner, but is off by default due + to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead. """ self._jsc.setJobGroup(groupId, description, interruptOnCancel) @@ -688,7 +692,7 @@ def cancelAllJobs(self): """ self._jsc.sc().cancelAllJobs() - def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): + def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ Executes the given partitionFunc on the specified set of partitions, returning the result as an array of elements. @@ -703,7 +707,7 @@ def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True) [0, 1, 16, 25] """ - if partitions == None: + if partitions is None: partitions = range(rdd._jrdd.partitions().size()) javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) @@ -714,6 +718,7 @@ def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) + def _test(): import atexit import doctest diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 5eb1c63bf206b..8a5873ded2b8b 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -42,12 +42,12 @@ def should_exit(): def compute_real_exit_code(exit_code): - # SystemExit's code can be integer or string, but os._exit only accepts integers - import numbers - if isinstance(exit_code, numbers.Integral): - return exit_code - else: - return 1 + # SystemExit's code can be integer or string, but os._exit only accepts integers + import numbers + if isinstance(exit_code, numbers.Integral): + return exit_code + else: + return 1 def worker(listen_sock): diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2a17127a7e0f9..2c129679f47f3 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -24,6 +24,7 @@ from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient + def launch_gateway(): SPARK_HOME = os.environ["SPARK_HOME"] diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 5f3a7e71f7866..b0f1cc1927066 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -33,10 +33,11 @@ from pyspark.resultiterable import ResultIterable + def _do_python_join(rdd, other, numPartitions, dispatch): vs = rdd.map(lambda (k, v): (k, (1, v))) ws = other.map(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x : dispatch(x.__iter__())) + return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__())) def python_join(rdd, other, numPartitions): @@ -85,6 +86,7 @@ def make_mapper(i): vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) rdd_len = len(vrdds) + def dispatch(seq): bufs = [[] for i in range(rdd_len)] for (n, v) in seq: diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e609b60a0f968..43b491a9716fc 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -164,7 +164,7 @@ def _deserialize_double_vector(ba, offset=0): nb = len(ba) - offset if nb < 5: raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is 
too short" % nb) + "which is too short" % nb) if ba[offset] == DENSE_VECTOR_MAGIC: return _deserialize_dense_vector(ba, offset) elif ba[offset] == SPARSE_VECTOR_MAGIC: @@ -272,6 +272,7 @@ def _serialize_labeled_point(p): header_float[0] = p.label return header + serialized_features + def _deserialize_labeled_point(ba, offset=0): """Deserialize a LabeledPoint from a mutually understood format.""" from pyspark.mllib.regression import LabeledPoint @@ -283,6 +284,7 @@ def _deserialize_labeled_point(ba, offset=0): features = _deserialize_double_vector(ba, offset + 9) return LabeledPoint(label, features) + def _copyto(array, buffer, offset, shape, dtype): """ Copy the contents of a vector to a destination bytearray at the diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index db39ed0acdb66..71f4ad1a8d44e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -247,6 +247,7 @@ def stringify(vector): else: return "[" + ",".join([str(v) for v in vector]) + "]" + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index e24c144f458bd..a707a9dcd5b49 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -24,7 +24,6 @@ from pyspark.serializers import NoOpSerializer - class MLUtils: """ Helper methods to load, save and pre-process data used in MLlib. @@ -154,7 +153,6 @@ def saveAsLibSVMFile(data, dir): lines = data.map(lambda p: MLUtils._convert_labeled_point_to_libsvm(p)) lines.saveAsTextFile(dir) - @staticmethod def loadLabeledPoints(sc, path, minPartitions=None): """ diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 122bc38b03b0c..7ff1c316c7623 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -18,13 +18,16 @@ import sys import random + class RDDSampler(object): def __init__(self, withReplacement, fraction, seed=None): try: import numpy self._use_numpy = True except ImportError: - print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." + print >> sys.stderr, ( + "NumPy does not appear to be installed. " + "Falling back to default random generator for sampling.") self._use_numpy = False self._seed = seed if seed is not None else random.randint(0, sys.maxint) @@ -61,7 +64,7 @@ def getUniformSample(self, split): def getPoissonSample(self, split, mean): if not self._rand_initialized or split != self._split: self.initRandomGenerator(split) - + if self._use_numpy: return self._random.poisson(mean) else: @@ -80,30 +83,27 @@ def getPoissonSample(self, split, mean): num_arrivals += 1 return (num_arrivals - 1) - + def shuffle(self, vals): if self._random is None: self.initRandomGenerator(0) # this should only ever called on the master so # the split does not matter - + if self._use_numpy: self._random.shuffle(vals) else: self._random.shuffle(vals, self._random.random) def func(self, split, iterator): - if self._withReplacement: + if self._withReplacement: for obj in iterator: - # For large datasets, the expected number of occurrences of each element in a sample with - # replacement is Poisson(frac). We use that to get a count for each element. - count = self.getPoissonSample(split, mean = self._fraction) + # For large datasets, the expected number of occurrences of each element in + # a sample with replacement is Poisson(frac). We use that to get a count for + # each element. 
+ count = self.getPoissonSample(split, mean=self._fraction) for _ in range(0, count): yield obj else: for obj in iterator: if self.getUniformSample(split) <= self._fraction: yield obj - - - - diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index 7f418f8d2e29a..df34740fc8176 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -19,6 +19,7 @@ import collections + class ResultIterable(collections.Iterable): """ A special result iterable. This is used because the standard iterator can not be pickled @@ -27,7 +28,9 @@ def __init__(self, data): self.data = data self.index = 0 self.maxindex = len(data) + def __iter__(self): return iter(self.data) + def __len__(self): return len(self.data) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b253807974a2e..9be78b39fbc21 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -91,7 +91,6 @@ def load_stream(self, stream): """ raise NotImplementedError - def _load_stream_without_unbatching(self, stream): return self.load_stream(stream) @@ -197,8 +196,8 @@ def _load_stream_without_unbatching(self, stream): return self.serializer.load_stream(stream) def __eq__(self, other): - return isinstance(other, BatchedSerializer) and \ - other.serializer == self.serializer + return (isinstance(other, BatchedSerializer) and + other.serializer == self.serializer) def __str__(self): return "BatchedSerializer<%s>" % str(self.serializer) @@ -229,8 +228,8 @@ def load_stream(self, stream): yield pair def __eq__(self, other): - return isinstance(other, CartesianDeserializer) and \ - self.key_ser == other.key_ser and self.val_ser == other.val_ser + return (isinstance(other, CartesianDeserializer) and + self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __str__(self): return "CartesianDeserializer<%s, %s>" % \ @@ -252,18 +251,20 @@ def load_stream(self, stream): yield pair def __eq__(self, other): - return isinstance(other, PairDeserializer) and \ - self.key_ser == other.key_ser and self.val_ser == other.val_ser + return (isinstance(other, PairDeserializer) and + self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __str__(self): - return "PairDeserializer<%s, %s>" % \ - (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): - def loads(self, obj): return obj - def dumps(self, obj): return obj + def loads(self, obj): + return obj + + def dumps(self, obj): + return obj class PickleSerializer(FramedSerializer): @@ -276,12 +277,16 @@ class PickleSerializer(FramedSerializer): not be as fast as more specialized serializers. 
""" - def dumps(self, obj): return cPickle.dumps(obj, 2) + def dumps(self, obj): + return cPickle.dumps(obj, 2) + loads = cPickle.loads + class CloudPickleSerializer(PickleSerializer): - def dumps(self, obj): return cloudpickle.dumps(obj, 2) + def dumps(self, obj): + return cloudpickle.dumps(obj, 2) class MarshalSerializer(FramedSerializer): diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 2ce5409cd67c2..e1e7cd954189f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -35,7 +35,8 @@ from pyspark.storagelevel import StorageLevel # this is the equivalent of ADD_JARS -add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") is not None else None +add_files = (os.environ.get("ADD_FILES").split(',') + if os.environ.get("ADD_FILES") is not None else None) if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index ffe177576f363..cb83e89176823 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -30,7 +30,7 @@ class SQLContext: tables, execute SQL over tables, cache tables, and read parquet files. """ - def __init__(self, sparkContext, sqlContext = None): + def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. @param sparkContext: The SparkContext to wrap. @@ -137,7 +137,6 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path): """Loads a text file storing one JSON object per line, returning the result as a L{SchemaRDD}. @@ -234,8 +233,8 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \ - "sbt/sbt assembly" , e) + raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " + "sbt/sbt assembly", e) def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) @@ -377,7 +376,7 @@ def registerAsTable(self, name): """ self._jschema_rdd.registerAsTable(name) - def insertInto(self, tableName, overwrite = False): + def insertInto(self, tableName, overwrite=False): """Inserts the contents of this SchemaRDD into the specified table. Optionally overwriting any existing data. @@ -420,7 +419,7 @@ def _toPython(self): # in Java land in the javaToPython function. 
May require a custom # pickle serializer in Pyrolite return RDD(jrdd, self._sc, BatchedSerializer( - PickleSerializer())).map(lambda d: Row(d)) + PickleSerializer())).map(lambda d: Row(d)) # We override the default cache/persist/checkpoint behavior as we want to cache the underlying # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class @@ -483,6 +482,7 @@ def subtract(self, other, numPartitions=None): else: raise ValueError("Can only subtract another SchemaRDD") + def _test(): import doctest from array import array @@ -493,20 +493,25 @@ def _test(): sc = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) - globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"}, - {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]) - jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}'] + globs['rdd'] = sc.parallelize( + [{"field1": 1, "field2": "row1"}, + {"field1": 2, "field2": "row2"}, + {"field1": 3, "field2": "row3"}] + ) + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}' + ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) globs['nestedRdd1'] = sc.parallelize([ - {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]) + {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, + {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) globs['nestedRdd2'] = sc.parallelize([ - {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]) - (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS) + {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)}, + {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}]) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) @@ -514,4 +519,3 @@ def _test(): if __name__ == "__main__": _test() - diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 080325061a697..e287bd3da1f61 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -20,18 +20,19 @@ import copy import math + class StatCounter(object): - + def __init__(self, values=[]): self.n = 0L # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) self.maxValue = float("-inf") self.minValue = float("inf") - + for v in values: self.merge(v) - + # Add a value into this StatCounter, updating the internal statistics. def merge(self, value): delta = value - self.mu @@ -42,7 +43,7 @@ def merge(self, value): self.maxValue = value if self.minValue > value: self.minValue = value - + return self # Merge another StatCounter into this one, adding up the internal statistics. 
@@ -50,7 +51,7 @@ def mergeStats(self, other): if not isinstance(other, StatCounter): raise Exception("Can only merge Statcounters!") - if other is self: # reference equality holds + if other is self: # reference equality holds self.merge(copy.deepcopy(other)) # Avoid overwriting fields in a weird order else: if self.n == 0: @@ -59,8 +60,8 @@ def mergeStats(self, other): self.n = other.n self.maxValue = other.maxValue self.minValue = other.minValue - - elif other.n != 0: + + elif other.n != 0: delta = other.mu - self.mu if other.n * 10 < self.n: self.mu = self.mu + (delta * other.n) / (self.n + other.n) @@ -68,10 +69,10 @@ def mergeStats(self, other): self.mu = other.mu - (delta * self.n) / (self.n + other.n) else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) - + self.maxValue = max(self.maxValue, other.maxValue) self.minValue = min(self.minValue, other.minValue) - + self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n return self @@ -94,7 +95,7 @@ def min(self): def max(self): return self.maxValue - + # Return the variance of the values. def variance(self): if self.n == 0: @@ -124,5 +125,5 @@ def sampleStdev(self): return math.sqrt(self.sampleVariance()) def __repr__(self): - return "(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min()) - + return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % + (self.count(), self.mean(), self.stdev(), self.max(), self.min())) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 3a18ea54eae4c..5d77a131f2856 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -17,6 +17,7 @@ __all__ = ["StorageLevel"] + class StorageLevel: """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -25,7 +26,7 @@ class StorageLevel: Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY. 
""" - def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication = 1): + def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication=1): self.useDisk = useDisk self.useMemory = useMemory self.useOffHeap = useOffHeap @@ -55,4 +56,4 @@ def __str__(self): StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2) StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False) StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2) -StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) \ No newline at end of file +StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c15bb457759ed..9c5ecd0bb02ab 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -52,12 +52,13 @@ class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name , batchSize=2) + self.sc = SparkContext('local[4]', class_name, batchSize=2) def tearDown(self): self.sc.stop() sys.path = self._old_sys_path + class TestCheckpoint(PySparkTestCase): def setUp(self): @@ -190,6 +191,7 @@ def test_deleting_input_files(self): def testAggregateByKey(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + def seqOp(x, y): x.add(y) return x @@ -197,17 +199,19 @@ def seqOp(x, y): def combOp(x, y): x |= y return x - + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) self.assertEqual(3, len(sets)) self.assertEqual(set([1]), sets[1]) self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + class TestIO(PySparkTestCase): def test_stdout_redirection(self): import subprocess + def func(x): subprocess.check_call('ls', shell=True) self.sc.parallelize([1]).foreach(func) @@ -479,7 +483,7 @@ def test_module_dependency(self): | return x + 1 """) proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], - stdout=subprocess.PIPE) + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f43210c6c0301..24d41b12d1b1a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,8 +57,8 @@ def main(infile, outfile): SparkFiles._is_running_on_worker = True # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added will be copied here - num_python_includes = read_int(infile) + sys.path.append(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename))