Skip to content

Commit

Permalink
Remove query kind when loading if present
Browse files Browse the repository at this point in the history
  • Loading branch information
greenape committed Apr 15, 2020
1 parent b6ac79d commit 0d726b0
Show file tree
Hide file tree
Showing 36 changed files with 556 additions and 620 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import fields, post_load, Schema
from marshmallow import fields
from marshmallow.validate import OneOf

from flowmachine.features.location.active_at_reference_location_counts import (
Expand All @@ -22,22 +22,13 @@
"ActiveAtReferenceLocationCountsExposed",
]

from .base_schema import BaseSchema

from .reference_location import ReferenceLocationSchema

from .unique_locations import UniqueLocationsSchema


class ActiveAtReferenceLocationCountsSchema(Schema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["active_at_reference_location_counts"]))
unique_locations = fields.Nested(UniqueLocationsSchema())
reference_locations = fields.Nested(ReferenceLocationSchema())

@post_load
def make_query_object(self, params, **kwargs):
return ActiveAtReferenceLocationCountsExposed(**params)


class ActiveAtReferenceLocationCountsExposed(BaseExposedQuery):
def __init__(
self, unique_locations, reference_locations,
Expand All @@ -64,3 +55,12 @@ def _flowmachine_query_obj(self):
),
)
)


class ActiveAtReferenceLocationCountsSchema(BaseSchema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["active_at_reference_location_counts"]))
unique_locations = fields.Nested(UniqueLocationsSchema())
reference_locations = fields.Nested(ReferenceLocationSchema())

__model__ = ActiveAtReferenceLocationCountsExposed
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,18 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import Schema, fields, post_load
from marshmallow.validate import OneOf, Length
from marshmallow_oneofschema import OneOfSchema
from marshmallow import fields
from marshmallow.validate import OneOf

from flowmachine.features import AggregateNetworkObjects
from .base_exposed_query import BaseExposedQuery
from .total_network_objects import TotalNetworkObjectsSchema, TotalNetworkObjectsExposed
from .base_schema import BaseSchema
from .total_network_objects import TotalNetworkObjectsSchema
from .custom_fields import Statistic, AggregateBy

__all__ = ["AggregateNetworkObjectsSchema", "AggregateNetworkObjectsExposed"]


class InputToAggregateNetworkObjectsSchema(OneOfSchema):
type_field = "query_kind"
type_schemas = {"total_network_objects": TotalNetworkObjectsSchema}


class AggregateNetworkObjectsSchema(Schema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["aggregate_network_objects"]))
total_network_objects = fields.Nested(
InputToAggregateNetworkObjectsSchema, required=True
)
statistic = Statistic()
aggregate_by = AggregateBy()

@post_load
def make_query_object(self, params, **kwargs):
return AggregateNetworkObjectsExposed(**params)


class AggregateNetworkObjectsExposed(BaseExposedQuery):
def __init__(self, *, total_network_objects, statistic, aggregate_by):
# Note: all input parameters need to be defined as attributes on `self`
Expand All @@ -57,3 +38,13 @@ def _flowmachine_query_obj(self):
statistic=self.statistic,
aggregate_by=self.aggregate_by,
)


class AggregateNetworkObjectsSchema(BaseSchema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["aggregate_network_objects"]))
total_network_objects = fields.Nested(TotalNetworkObjectsSchema, required=True)
statistic = Statistic()
aggregate_by = AggregateBy()

__model__ = AggregateNetworkObjectsExposed
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from abc import ABCMeta, abstractmethod
from marshmallow import Schema, fields
from marshmallow import fields

from .base_exposed_query import BaseExposedQuery
from .base_schema import BaseSchema
from .random_sample import RandomSampleSchema


class BaseQueryWithSamplingSchema(Schema):
class BaseQueryWithSamplingSchema(BaseSchema):
sampling = fields.Nested(RandomSampleSchema, allow_none=True)


Expand Down
18 changes: 18 additions & 0 deletions flowmachine/flowmachine/core/server/query_schemas/base_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import Schema, post_load


class BaseSchema(Schema):
@post_load
def remove_query_kind_if_present_and_load(self, params, **kwargs):
# Strip off query kind if present, because this isn't always wrapped in a OneOfSchema
return self.__model__(
**{
param_name: param_value
for param_name, param_value in params.items()
if param_name != "query_kind"
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import fields, post_load, Schema
from marshmallow import fields
from marshmallow.validate import OneOf

from flowmachine.features.location.redacted_consecutive_trips_od_matrix import (
Expand All @@ -13,25 +13,13 @@
ConsecutiveTripsODMatrix,
)
from . import BaseExposedQuery
from .base_schema import BaseSchema
from .custom_fields import SubscriberSubset
from .aggregation_unit import AggregationUnit, get_spatial_unit_obj

__all__ = ["ConsecutiveTripsODMatrixSchema", "ConsecutiveTripsODMatrixExposed"]


class ConsecutiveTripsODMatrixSchema(Schema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["consecutive_trips_od_matrix"]))
start_date = fields.Date(required=True)
end_date = fields.Date(required=True)
aggregation_unit = AggregationUnit()
subscriber_subset = SubscriberSubset()

@post_load
def make_query_object(self, params, **kwargs):
return ConsecutiveTripsODMatrixExposed(**params)


class ConsecutiveTripsODMatrixExposed(BaseExposedQuery):
def __init__(
self,
Expand Down Expand Up @@ -69,3 +57,14 @@ def _flowmachine_query_obj(self):
)
)
)


class ConsecutiveTripsODMatrixSchema(BaseSchema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["consecutive_trips_od_matrix"]))
start_date = fields.Date(required=True)
end_date = fields.Date(required=True)
aggregation_unit = AggregationUnit()
subscriber_subset = SubscriberSubset()

__model__ = ConsecutiveTripsODMatrixExposed
26 changes: 12 additions & 14 deletions flowmachine/flowmachine/core/server/query_schemas/daily_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import fields, post_load
from marshmallow import fields
from marshmallow.validate import OneOf

from flowmachine.features import daily_location
Expand All @@ -16,19 +16,6 @@
__all__ = ["DailyLocationSchema", "DailyLocationExposed"]


class DailyLocationSchema(BaseQueryWithSamplingSchema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["daily_location"]))
date = fields.Date(required=True)
method = fields.String(required=True, validate=OneOf(["last", "most-common"]))
aggregation_unit = AggregationUnit()
subscriber_subset = SubscriberSubset()

@post_load
def make_query_object(self, params, **kwargs):
return DailyLocationExposed(**params)


class DailyLocationExposed(BaseExposedQueryWithSampling):
def __init__(
self, date, *, method, aggregation_unit, subscriber_subset=None, sampling=None
Expand Down Expand Up @@ -56,3 +43,14 @@ def _unsampled_query_obj(self):
method=self.method,
subscriber_subset=self.subscriber_subset,
)


class DailyLocationSchema(BaseQueryWithSamplingSchema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["daily_location"]))
date = fields.Date(required=True)
method = fields.String(required=True, validate=OneOf(["last", "most-common"]))
aggregation_unit = AggregationUnit()
subscriber_subset = SubscriberSubset()

__model__ = DailyLocationExposed
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,18 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import Schema, fields, post_load
from marshmallow import fields
from marshmallow.validate import OneOf

from flowmachine.features.dfs import DFSTotalMetricAmount
from .base_exposed_query import BaseExposedQuery
from .base_schema import BaseSchema
from .custom_fields import DFSMetric
from .aggregation_unit import AggregationUnit

__all__ = ["DFSTotalMetricAmountSchema", "DFSTotalMetricAmountExposed"]


class DFSTotalMetricAmountSchema(Schema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["dfs_metric_total_amount"]))
metric = DFSMetric()
start_date = fields.Date(required=True)
end_date = fields.Date(required=True)
aggregation_unit = AggregationUnit()

@post_load
def make_query_object(self, params, **kwargs):
return DFSTotalMetricAmountExposed(**params)


class DFSTotalMetricAmountExposed(BaseExposedQuery):
def __init__(self, *, metric, start_date, end_date, aggregation_unit):
# Note: all input parameters need to be defined as attributes on `self`
Expand All @@ -50,3 +38,14 @@ def _flowmachine_query_obj(self):
end_date=self.end_date,
aggregation_unit=self.aggregation_unit,
)


class DFSTotalMetricAmountSchema(BaseSchema):
# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["dfs_metric_total_amount"]))
metric = DFSMetric()
start_date = fields.Date(required=True)
end_date = fields.Date(required=True)
aggregation_unit = AggregationUnit()

__model__ = DFSTotalMetricAmountExposed
32 changes: 15 additions & 17 deletions flowmachine/flowmachine/core/server/query_schemas/displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from marshmallow import fields, post_load
from marshmallow.validate import OneOf, Length
from marshmallow import fields
from marshmallow.validate import OneOf
from marshmallow_oneofschema import OneOfSchema

from flowmachine.features import Displacement
from .custom_fields import SubscriberSubset, Statistic
from .daily_location import DailyLocationSchema, DailyLocationExposed
from .modal_location import ModalLocationSchema, ModalLocationExposed
from .daily_location import DailyLocationSchema
from .modal_location import ModalLocationSchema
from .base_query_with_sampling import (
BaseQueryWithSamplingSchema,
BaseExposedQueryWithSampling,
Expand All @@ -27,19 +27,6 @@ class InputToDisplacementSchema(OneOfSchema):
}


class DisplacementSchema(BaseQueryWithSamplingSchema):
query_kind = fields.String(validate=OneOf(["displacement"]))
start = fields.Date(required=True)
stop = fields.Date(required=True)
statistic = Statistic()
reference_location = fields.Nested(InputToDisplacementSchema, many=False)
subscriber_subset = SubscriberSubset()

@post_load
def make_query_object(self, params, **kwargs):
return DisplacementExposed(**params)


class DisplacementExposed(BaseExposedQueryWithSampling):
def __init__(
self,
Expand Down Expand Up @@ -76,3 +63,14 @@ def _unsampled_query_obj(self):
reference_location=self.reference_location._flowmachine_query_obj,
subscriber_subset=self.subscriber_subset,
)


class DisplacementSchema(BaseQueryWithSamplingSchema):
query_kind = fields.String(validate=OneOf(["displacement"]))
start = fields.Date(required=True)
stop = fields.Date(required=True)
statistic = Statistic()
reference_location = fields.Nested(InputToDisplacementSchema, many=False)
subscriber_subset = SubscriberSubset()

__model__ = DisplacementExposed
32 changes: 16 additions & 16 deletions flowmachine/flowmachine/core/server/query_schemas/dummy_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from time import sleep

from marshmallow import Schema, fields, post_load
from marshmallow import fields
from marshmallow.validate import OneOf

from flowmachine.core.dummy_query import DummyQuery
Expand All @@ -12,21 +12,7 @@

__all__ = ["DummyQuerySchema", "DummyQueryExposed"]


class DummyQuerySchema(Schema):
"""
Dummy query useful for testing.
"""

# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["dummy_query"]))
dummy_param = fields.String(required=True)
aggregation_unit = AggregationUnit()
dummy_delay = fields.Integer(missing=0, required=False)

@post_load
def make_query_object(self, params, **kwargs):
return DummyQueryExposed(**params)
from .base_schema import BaseSchema


class DummyQueryExposed(BaseExposedQuery):
Expand All @@ -41,3 +27,17 @@ def __init__(self, dummy_param, aggregation_unit, dummy_delay):
def _flowmachine_query_obj(self):
sleep(self.dummy_delay)
return DummyQuery(dummy_param=self.dummy_param)


class DummyQuerySchema(BaseSchema):
"""
Dummy query useful for testing.
"""

# query_kind parameter is required here for claims validation
query_kind = fields.String(validate=OneOf(["dummy_query"]))
dummy_param = fields.String(required=True)
aggregation_unit = AggregationUnit()
dummy_delay = fields.Integer(missing=0, required=False)

__model__ = DummyQueryExposed
Loading

0 comments on commit 0d726b0

Please sign in to comment.