# -*- coding: utf-8 -*-
# Copyright 2012-2014 John Whitlock
# Copyright 2014 Juha Yrjölä
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import unicode_literals
from codecs import BOM_UTF8
from collections import defaultdict
from csv import reader, writer
from datetime import datetime, date
from logging import getLogger
import re
from django.contrib.gis.db import models
from django.contrib.gis.db.models.query import GeoQuerySet
from django.db.models.fields.related import ManyToManyField
from django.utils.six import StringIO, text_type, PY3
from multigtfs.compat import get_blank_value
logger = getLogger(__name__)
re_point = re.compile(r'(?P<name>point)\[(?P<index>\d)\]')
batch_size = 1000
large_queryset_size = 100000
class BaseQuerySet(GeoQuerySet):
def populated_column_map(self):
'''Return the _column_map without unused optional fields'''
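# The result is a list of (csv_name, field_pattern) pairs in _column_map
# order; an illustrative (not definitive) slice for a stop-like model might
# be [('stop_id', 'stop_id'), ('stop_lat', 'point[1]')].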
column_map = []
cls = self.model
for csv_name, field_pattern in cls._column_map:
# Separate the local field name from foreign columns
if '__' in field_pattern:
field_name = field_pattern.split('__', 1)[0]
else:
field_name = field_pattern
# Handle point fields
point_match = re_point.match(field_name)
if point_match:
field = None
else:
field = cls._meta.get_field(field_name)
# Only add optional columns if they are used in the records
if field and field.blank and not field.has_default():
kwargs = {field_name: get_blank_value(field)}
if self.exclude(**kwargs).exists():
column_map.append((csv_name, field_pattern))
else:
column_map.append((csv_name, field_pattern))
return column_map
class BaseManager(models.GeoManager):
def get_queryset(self):
'''Django 1.8 expects this method name. Simply calling the other
method results in a recursion error in some Python interpreters.
'''
return BaseQuerySet(self.model)
def get_query_set(self):
return BaseQuerySet(self.model)
def in_feed(self, feed):
'''Return the objects in the target feed'''
kwargs = {self.model._rel_to_feed: feed}
return self.filter(**kwargs)
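# Usage sketch (any concrete subclass works here; 'Trip' is illustrative):
#     Trip.objects.in_feed(feed)
# filters on the subclass's _rel_to_feed relation, so related models can
# reach the feed through a chained lookup such as 'route__feed'.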
class Base(models.Model):
"""Base class for models that are defined in the GTFS spec
Implementers need to define a class variable:
_column_map - A mapping of GTFS columns to model fields
It should be set to a sequence of tuples:
- GTFS column name
- Model field name
If the column is optional, then set blank=True on the field, and set
null=True appropriately.
Implementers can define this class variable:
_rel_to_feed - The relation of this model to the feed, in Django filter
format. The default is 'feed', and will be used to get the objects
on a feed like this:
Model.objects.filter(_rel_to_feed=feed)
"""
class Meta:
abstract = True
app_label = 'multigtfs'
objects = BaseManager()
# The relation of the model to the feed it belongs to.
_rel_to_feed = 'feed'
@classmethod
def import_txt(cls, txt_file, feed, filter_func=None):
'''Import from the GTFS text file'''
# Set up the conversion from GTFS to Django format
# Conversion functions
def no_convert(value): return value
def date_convert(value): return datetime.strptime(value, '%Y%m%d')
def bool_convert(value): return (value == '1')
def char_convert(value): return (value or '')
def null_convert(value): return (value or None)
def point_convert(value): return (value or 0.0)
cache = {}
def default_convert(field):
def get_value_or_default(value):
if value == '' or value is None:
return field.get_default()
else:
return value
return get_value_or_default
def instance_convert(field, feed, rel_name):
def get_instance(value):
if value.strip():
key1 = "{}:{}".format(field.rel.to.__name__, rel_name)
key2 = text_type(value)
# Load existing objects
if key1 not in cache:
pairs = field.rel.to.objects.filter(
**{field.rel.to._rel_to_feed: feed}).values_list(
rel_name, 'id')
cache[key1] = dict((text_type(x), i) for x, i in pairs)
# Create new?
if key2 not in cache[key1]:
kwargs = {
field.rel.to._rel_to_feed: feed,
rel_name: value}
cache[key1][key2] = field.rel.to.objects.create(
**kwargs).id
return cache[key1][key2]
else:
return None
return get_instance
# Check unique fields
column_names = [c for c, _ in cls._column_map]
for unique_field in cls._unique_fields:
assert unique_field in column_names, \
'{} not in {}'.format(unique_field, column_names)
# Map of field_name to converters from GTFS to Django format
val_map = dict()
name_map = dict()
point_map = dict()
for csv_name, field_pattern in cls._column_map:
# Separate the local field name from foreign columns
if '__' in field_pattern:
field_base, rel_name = field_pattern.split('__', 1)
field_name = field_base + '_id'
else:
field_name = field_base = field_pattern
# Use the field name in the name mapping
name_map[csv_name] = field_name
# Is it a point field?
point_match = re_point.match(field_name)
if point_match:
field = None
else:
field = cls._meta.get_field(field_base)
# Pick a conversion function for the field
if point_match:
converter = point_convert
elif isinstance(field, models.DateField):
converter = date_convert
elif isinstance(field, models.BooleanField):
converter = bool_convert
elif isinstance(field, models.CharField):
converter = char_convert
elif field.rel:
converter = instance_convert(field, feed, rel_name)
assert not isinstance(field, models.ManyToManyField)
elif field.null:
converter = null_convert
elif field.has_default():
converter = default_convert(field)
else:
converter = no_convert
if point_match:
index = int(point_match.group('index'))
point_map[csv_name] = (index, converter)
else:
val_map[csv_name] = converter
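# After this loop: name_map maps CSV column -> model field name (with an
# '_id' suffix for foreign keys), val_map maps CSV column -> conversion
# function, and point_map maps CSV column -> (coordinate index, converter)
# for columns that populate a point field.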
# Read and convert the source txt
csv_reader = reader(txt_file)
unique_line = dict()
count = 0
first = True
extra_counts = defaultdict(int)
if PY3: # pragma: no cover
bom = BOM_UTF8.decode('utf-8')
else: # pragma: no cover
bom = BOM_UTF8
new_objects = []
for row in csv_reader:
if first:
# Read the columns
columns = row
if columns[0].startswith(bom):
columns[0] = columns[0][len(bom):]
first = False
continue
if filter_func and not filter_func(zip(columns, row)):
continue
# Read a data row
fields = dict()
point_coords = [None, None]
ukey_values = {}
if cls._rel_to_feed == 'feed':
fields['feed'] = feed
for column_name, value in zip(columns, row):
if column_name not in name_map:
val = null_convert(value)
if val is not None:
fields.setdefault('extra_data', {})[column_name] = val
extra_counts[column_name] += 1
elif column_name in val_map:
fields[name_map[column_name]] = val_map[column_name](value)
else:
assert column_name in point_map
pos, converter = point_map[column_name]
point_coords[pos] = converter(value)
# Is it part of the unique key?
if column_name in cls._unique_fields:
ukey_values[column_name] = value
# Join the lat/long into a point
if point_map:
assert point_coords[0] and point_coords[1]
fields['point'] = "POINT(%s)" % (' '.join(point_coords))
# Is the item unique?
ukey = tuple(ukey_values.get(u) for u in cls._unique_fields)
if ukey in unique_line:
logger.warning(
'%s line %d is a duplicate of line %d, not imported.',
cls._filename, csv_reader.line_num, unique_line[ukey])
continue
else:
unique_line[ukey] = csv_reader.line_num
# Create after accumulating a batch
new_objects.append(cls(**fields))
if len(new_objects) % batch_size == 0: # pragma: no cover
cls.objects.bulk_create(new_objects)
count += len(new_objects)
logger.info(
"Imported %d %s",
count, cls._meta.verbose_name_plural)
new_objects = []
# Create remaining objects
if new_objects:
cls.objects.bulk_create(new_objects)
# Take note of extra fields
if extra_counts:
extra_columns = feed.meta.setdefault(
'extra_columns', {}).setdefault(cls.__name__, [])
for column in columns:
if column in extra_counts and column not in extra_columns:
extra_columns.append(column)
feed.save()
return len(unique_line)
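# Typical call pattern (a sketch; 'Stop' stands in for any concrete
# subclass, and 'feed' is an existing Feed instance):
#     with open('stops.txt', 'rt') as txt_file:
#         Stop.import_txt(txt_file, feed)
# The return value is the number of unique, unfiltered rows imported.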
@classmethod
def export_txt(cls, feed):
'''Export records as a GTFS comma-separated file'''
objects = cls.objects.in_feed(feed)
# If no records, return None
if not objects.exists():
return
# Get the columns used in the dataset
column_map = objects.populated_column_map()
columns, fields = zip(*column_map)
extra_columns = feed.meta.get(
'extra_columns', {}).get(cls.__name__, [])
# Get sort order
if hasattr(cls, '_sort_order'):
sort_fields = cls._sort_order
else:
sort_fields = []
for field in fields:
base_field = field.split('__', 1)[0]
point_match = re_point.match(base_field)
if point_match:
continue
field_type = cls._meta.get_field(base_field)
assert not isinstance(field_type, ManyToManyField)
sort_fields.append(field)
# Create CSV writer
out = StringIO()
csv_writer = writer(out, lineterminator='\n')
# Write header row
header_row = [text_type(c) for c in columns]
header_row.extend(extra_columns)
write_rows(csv_writer, [header_row])
# Report the work to be done
total = objects.count()
logger.info(
'%d %s to export...',
total, cls._meta.verbose_name_plural)
# Populate related items cache
model_to_field_name = {}
cache = {}
for field_name in fields:
if '__' in field_name:
local_field_name, subfield_name = field_name.split('__', 1)
field = cls._meta.get_field(local_field_name)
field_type = field.rel.to
model_name = field_type.__name__
if model_name in model_to_field_name:
# Already loaded this model under a different field name
cache[field_name] = cache[model_to_field_name[model_name]]
else:
# Load all feed data for this model
pairs = field_type.objects.in_feed(
feed).values_list('id', subfield_name)
cache[field_name] = dict(
(i, text_type(x)) for i, x in pairs)
cache[field_name][None] = u''
model_to_field_name[model_name] = field_name
# For large querysets, break up by the first field
if total < large_queryset_size:
querysets = [objects.order_by(*sort_fields)]
else: # pragma: no cover
field1_raw = sort_fields[0]
assert '__' in field1_raw
assert field1_raw in cache
field1 = field1_raw.split('__', 1)[0]
field1_id = field1 + '_id'
# Sort field1 ids by field1 values
val_to_id = dict((v, k) for k, v in cache[field1_raw].items())
assert len(val_to_id) == len(cache[field1_raw])
sorted_vals = sorted(val_to_id.keys())
querysets = []
for val in sorted_vals:
fid = val_to_id[val]
if fid:
qs = objects.filter(
**{field1_id: fid}).order_by(*sort_fields[1:])
querysets.append(qs)
# Assemble the rows, writing when we hit batch size
count = 0
rows = []
for queryset in querysets:
for item in queryset.order_by(*sort_fields):
row = []
for csv_name, field_name in column_map:
obj = item
point_match = re_point.match(field_name)
if '__' in field_name:
# Return relations from cache
local_field_name = field_name.split('__', 1)[0]
field_id = getattr(obj, local_field_name + '_id')
row.append(cache[field_name][field_id])
elif point_match:
# Get the lat or long from the point
name, index = point_match.groups()
field = getattr(obj, name)
row.append(field.coords[int(index)])
else:
# Handle other field types
field = getattr(obj, field_name) if obj else ''
if isinstance(field, date):
formatted = field.strftime(u'%Y%m%d')
row.append(text_type(formatted))
elif isinstance(field, bool):
row.append(1 if field else 0)
elif field is None:
row.append(u'')
else:
row.append(text_type(field))
for col in extra_columns:
row.append(obj.extra_data.get(col, u''))
rows.append(row)
if len(rows) % batch_size == 0: # pragma: no cover
write_rows(csv_writer, rows)
count += len(rows)
logger.info(
"Exported %d %s",
count, cls._meta.verbose_name_plural)
rows = []
# Write rows smaller than batch size
write_rows(csv_writer, rows)
return out.getvalue()
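# Typical call pattern (a sketch mirroring import_txt above):
#     content = Stop.export_txt(feed)
#     if content:
#         with open('stops.txt', 'wt') as out_file:
#             out_file.write(content)
# export_txt returns None when the feed has no records for this model.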
def write_rows(writer, rows):
'''Write a batch of row data to the csv writer'''
for row in rows:
try:
writer.writerow(row)
except UnicodeEncodeError: # pragma: no cover
# The Python 2 csv module can't write non-ASCII unicode; encode to UTF-8 first
new_row = []
for item in row:
if isinstance(item, text_type):
new_row.append(item.encode('utf-8'))
else:
new_row.append(item)
writer.writerow(new_row)