Commit

Update workflow (#24)

jrbourbeau committed Mar 12, 2024
1 parent 36bcd72 commit a00d25c
Showing 11 changed files with 227 additions and 285 deletions.
98 changes: 53 additions & 45 deletions dashboard.py
@@ -1,62 +1,70 @@
import dask.dataframe as dd
import plotly.express as px
import time

import pandas as pd
import streamlit as st

from pipeline.settings import RESULTS_DIR


@st.cache_data
def get_data(region, part_type):
return dd.read_parquet(
RESULTS_DIR / region / part_type.upper() / "*.parquet"
).compute()

def get_data(segment):
return pd.read_parquet(RESULTS_DIR / f"{segment.lower()}.snappy.parquet")

description = """
### Recommended Suppliers
_Some text that explains the business problem being addressed..._

This query finds which supplier should be selected to place an order for a given part in a given region.
st.markdown(
"""
### Top Unshipped Orders
_Top 50 unshipped orders with the highest revenue._
"""
st.markdown(description)
regions = list(map(str.title, ["EUROPE", "AFRICA", "AMERICA", "ASIA", "MIDDLE EAST"]))
region = st.selectbox(
"Region",
regions,
index=None,
placeholder="Please select a region...",
)
part_types = list(map(str.title, ["COPPER", "BRASS", "TIN", "NICKEL", "STEEL"]))
part_type = st.selectbox(
"Part Type",
part_types,

SEGMENTS = ["automobile", "building", "furniture", "machinery", "household"]


def files_exist():
# Do we have all the files needed for the dashboard?
files = list(RESULTS_DIR.rglob("*.snappy.parquet"))
return len(files) == len(SEGMENTS)


with st.spinner("Waiting for data..."):
while not files_exist():
time.sleep(5)

segments = list(
map(str.title, ["automobile", "building", "furniture", "machinery", "household"])
)
segment = st.selectbox(
"Segment",
segments,
index=None,
placeholder="Please select a part type...",
placeholder="Please select a product segment...",
)
if region and part_type:
df = get_data(region, part_type)
if segment:
df = get_data(segment)
df = df.drop(columns="o_shippriority")
df["l_orderkey"] = df["l_orderkey"].map(lambda x: f"{x:09}")
df["revenue"] = df["revenue"].round(2)
df = df.rename(
columns={
"n_name": "Country",
"s_name": "Supplier",
"s_acctbal": "Balance",
"p_partkey": "Part ID",
"l_orderkey": "Order ID",
"o_order_time": "Date Ordered",
"revenue": "Revenue",
}
)
maxes = df.groupby("Country").Balance.idxmax()
data = df.loc[maxes]
figure = px.choropleth(
data,
locationmode="country names",
locations="Country",
featureidkey="Supplier",
color="Balance",
color_continuous_scale="viridis",
hover_data=["Country", "Supplier", "Balance"],

df = df.set_index("Order ID")
st.dataframe(
df.style.format({"Revenue": "${:,}"}),
column_config={
"Date Ordered": st.column_config.DateColumn(
"Date Ordered",
format="MM/DD/YYYY",
help="Date order was placed",
),
"Revenue": st.column_config.NumberColumn(
"Revenue (in USD)",
help="Total revenue of order",
),
},
)
st.plotly_chart(figure, theme="streamlit", use_container_width=True)
on = st.toggle("Show data")
if on:
st.write(
df[["Country", "Supplier", "Balance", "Part ID"]], use_container_width=True
)
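
The reworked dashboard blocks in files_exist until one result file per segment exists, then get_data reads that segment's parquet file. A small sketch of the layout it polls for, inferred from SEGMENTS and get_data (the diff itself does not list the directory contents):

from pipeline.settings import RESULTS_DIR

SEGMENTS = ["automobile", "building", "furniture", "machinery", "household"]

# One snappy-compressed parquet file per product segment, e.g.
# RESULTS_DIR / "automobile.snappy.parquet"
expected_files = [RESULTS_DIR / f"{segment}.snappy.parquet" for segment in SEGMENTS]
print(expected_files)
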
6 changes: 5 additions & 1 deletion environment.yml
@@ -17,4 +17,8 @@ dependencies:
- s3fs
- universal_pathlib <0.2.0
- boto3
- dask-deltatable
- deltalake=0.15.3
# - dask-deltatable
- pip
- pip:
- git+https://github.com/fjetter/dask-deltatable.git@dask_expr
61 changes: 61 additions & 0 deletions pipeline/dashboard.py
@@ -0,0 +1,61 @@
import os
import shlex
import subprocess

import coiled
import requests
from prefect import flow
from rich import print

from .settings import DASHBOARD_FILE, LOCAL, REGION

port = 8080
name = "etl-tpch-dashboard"
subdomain = "etl-tpch"


def deploy():
print("[green]Deploying dashboard...[/green]")
cmd = f"streamlit run {DASHBOARD_FILE} --server.port {port} --server.headless true"
if LOCAL:
subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
else:
cmd = f"""
coiled run \
--region {REGION} \
--vm-type t3.medium \
-f dashboard.py \
-f pipeline \
--subdomain {subdomain} \
--port {port} \
-e AWS_ACCESS_KEY_ID={os.environ['AWS_ACCESS_KEY_ID']} \
-e AWS_SECRET_ACCESS_KEY={os.environ['AWS_SECRET_ACCESS_KEY']} \
--detach \
--keepalive '520 weeks' \
--name {name} \
-- \
{cmd}
"""
subprocess.run(shlex.split(cmd))
print(f"Dashboard is available at [blue]{get_address()}[/blue] :rocket:")


def get_address():
if LOCAL:
return f"http://0.0.0.0:{port}"
else:
with coiled.Cloud() as cloud:
account = cloud.default_account
return f"http://{subdomain}.{account}.dask.host:{port}"


@flow(log_prints=True)
def deploy_dashboard():
address = get_address()
try:
r = requests.get(address)
r.raise_for_status()
except Exception:
deploy()
else:
print("Dashboard is healthy")
54 changes: 48 additions & 6 deletions pipeline/data.py
@@ -1,21 +1,29 @@
import datetime
import os
import uuid

import coiled
import duckdb
import pandas as pd
import psutil
from dask.distributed import print
from prefect import flow, task

from .settings import LOCAL, PROCESSED_DIR, REGION, STAGING_DIR, fs, lock_generate


def new_time(t, t_start=None, t_end=None):

d = pd.Timestamp("1998-12-31") - pd.Timestamp("1992-01-01")
return t_start + (t - pd.Timestamp("1992-01-01")) * ((t_end - t_start) / d)


@task(log_prints=True)
@coiled.function(
name="data-generation",
local=LOCAL,
region=REGION,
keepalive="5 minutes",
vm_type="m6i.2xlarge",
tags={"workflow": "etl-tpch"},
)
def generate(scale: float, path: os.PathLike) -> None:
@@ -42,7 +50,8 @@ def generate(scale: float, path: os.PathLike) -> None:
.arrow()
.column("table_name")
)
for table in map(str, tables):
now = pd.Timestamp.now()
for table in reversed(sorted(map(str, tables))):
if table in static_tables and (
list((STAGING_DIR / table).rglob("*.json"))
or list((PROCESSED_DIR / table).rglob("*.parquet"))
@@ -51,15 +60,49 @@ def generate(scale: float, path: os.PathLike) -> None:
continue
print(f"Exporting table: {table}")
stmt = f"""select * from {table}"""
df = con.sql(stmt).arrow()
df = con.sql(stmt).df()

# Make order IDs unique across multiple data generation cycles
if table == "orders":
# Generate new, random uuid order IDs
df["o_orderkey_new"] = pd.Series(
(uuid.uuid4().hex for _ in range(df.shape[0])),
dtype="string[pyarrow]",
)
orderkey_new = df[["o_orderkey", "o_orderkey_new"]].set_index(
"o_orderkey"
)
df = df.drop(columns="o_orderkey").rename(
columns={"o_orderkey_new": "o_orderkey"}
)
elif table == "lineitem":
# Join with `orderkey_new` mapping to convert old order IDs to new order IDs
df = (
df.set_index("l_orderkey")
.join(orderkey_new)
.reset_index(drop=True)
.rename(columns={"o_orderkey_new": "l_orderkey"})
)

# Shift times to be more recent
if table == "lineitem":
df["l_shipdate"] = new_time(
df["l_shipdate"], t_start=now, t_end=now + pd.Timedelta("7 days")
)
df = df.rename(columns={"l_shipdate": "l_ship_time"})
cols = [c for c in df.columns if "date" in c]
df[cols] = new_time(
df[cols], t_start=now - pd.Timedelta("15 minutes"), t_end=now
)
df = df.rename(columns={c: c.replace("date", "_time") for c in cols})

outfile = (
path
/ table
/ f"{table}_{datetime.datetime.now().isoformat().split('.')[0]}.json"
)
fs.makedirs(outfile.parent, exist_ok=True)
df.to_pandas().to_json(
df.to_json(
outfile,
date_format="iso",
orient="records",
@@ -73,7 +116,6 @@ def generate(scale: float, path: os.PathLike) -> None:
def generate_data():
with lock_generate:
generate(
scale=0.01,
scale=1,
path=STAGING_DIR,
)
generate.fn.client.restart(wait_for_workers=False)
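
Two pieces of the updated generate task are worth unpacking: new_time linearly rescales the original TPC-H date range (1992-01-01 through 1998-12-31) onto a recent window, and the orders/lineitem handling swaps integer order keys for random UUIDs while keeping the lineitem foreign keys consistent. A toy illustration (data and timestamps below are made up, not taken from the commit):

import uuid

import pandas as pd

def new_time(t, t_start=None, t_end=None):
    # Same rescaling as in the diff: map [1992-01-01, 1998-12-31] onto [t_start, t_end].
    d = pd.Timestamp("1998-12-31") - pd.Timestamp("1992-01-01")
    return t_start + (t - pd.Timestamp("1992-01-01")) * ((t_end - t_start) / d)

now = pd.Timestamp("2024-03-12 12:00:00")
midpoint = new_time(
    pd.Timestamp("1995-07-02"),  # roughly the middle of the original range
    t_start=now - pd.Timedelta("15 minutes"),
    t_end=now,
)
print(midpoint)  # ~2024-03-12 11:52:30, i.e. halfway through the 15-minute window

# Orders get fresh UUID keys; the old->new mapping is then joined onto lineitem.
orders = pd.DataFrame({"o_orderkey": [1, 2], "o_totalprice": [10.0, 20.0]})
orders["o_orderkey_new"] = [uuid.uuid4().hex for _ in range(len(orders))]
orderkey_new = orders[["o_orderkey", "o_orderkey_new"]].set_index("o_orderkey")
orders = orders.drop(columns="o_orderkey").rename(columns={"o_orderkey_new": "o_orderkey"})

lineitem = pd.DataFrame({"l_orderkey": [1, 1, 2], "l_quantity": [5, 3, 7]})
lineitem = (
    lineitem.set_index("l_orderkey")
    .join(orderkey_new)
    .reset_index(drop=True)
    .rename(columns={"o_orderkey_new": "l_orderkey"})
)
print(lineitem)  # every row now carries the UUID of its parent order
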
9 changes: 0 additions & 9 deletions pipeline/monitor.py

This file was deleted.

22 changes: 5 additions & 17 deletions pipeline/preprocess.py
@@ -7,7 +7,6 @@
from prefect.tasks import exponential_backoff

from .settings import (
ARCHIVE_DIR,
LOCAL,
PROCESSED_DIR,
REGION,
@@ -29,7 +28,7 @@
name="data-etl",
local=LOCAL,
region=REGION,
keepalive="5 minutes",
vm_type="m6i.2xlarge",
tags={"workflow": "etl-tpch"},
)
def json_file_to_parquet(file):
@@ -42,19 +41,10 @@ def json_file_to_parquet(file):
deltalake.write_deltalake(
outfile, data, mode="append", storage_options=storage_options
)
print(f"Saved {outfile}")
fs.rm(str(file))
return file


@task
def archive_json_file(file):
outfile = ARCHIVE_DIR / file.relative_to(STAGING_DIR)
fs.makedirs(outfile.parent, exist_ok=True)
fs.mv(str(file), str(outfile))

return outfile


def list_new_json_files():
return list(STAGING_DIR.rglob("*.json"))

@@ -63,18 +53,16 @@ def list_new_json_files():
def json_to_parquet():
with lock_json_to_parquet:
files = list_new_json_files()
files = json_file_to_parquet.map(files)
futures = archive_json_file.map(files)
futures = json_file_to_parquet.map(files)
for f in futures:
print(f"Archived {str(f.result())}")
print(f"Processed {str(f.result())}")


@task(log_prints=True)
@coiled.function(
name="data-etl",
local=LOCAL,
region=REGION,
keepalive="5 minutes",
vm_type="m6i.xlarge",
tags={"workflow": "etl-tpch"},
)
def compact(table):

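
On the preprocessing side, json_file_to_parquet now appends each staged JSON batch to a Delta table via deltalake.write_deltalake instead of archiving the JSON. A hypothetical read-back of one processed table with the deltalake package pinned in environment.yml (the local path and table name are illustrative; a remote store would need the same storage_options the task passes to write_deltalake):

from deltalake import DeltaTable

# Stands in for PROCESSED_DIR / "lineitem"; adjust to the actual layout.
dt = DeltaTable("data/processed/lineitem")
df = dt.to_pandas()
print(df.head())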