-
Notifications
You must be signed in to change notification settings - Fork 123
/
transform.py
190 lines (149 loc) · 6.07 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import textwrap
import dlt
import ibis
import ibis.expr.types as ir
import openai
from hamilton.function_modifiers import pipe, source, step
from hamilton.htypes import Collect, Parallelizable
def db_con(pipeline: dlt.Pipeline) -> ibis.BaseBackend:
    """Open the Ibis connection for the pipeline and make it the default backend.

    The database file is named after the pipeline (`<pipeline_name>.duckdb`).
    The connection is also registered through `ibis.set_backend`, so later
    calls to `ibis.get_backend()` return it.
    """
    database_file = f"{pipeline.pipeline_name}.duckdb"
    connection = ibis.connect(database_file)
    ibis.set_backend(connection)
    return connection
def channel(selected_channels: list[str]) -> Parallelizable[str]:
    """Yield, one at a time, each channel whose messages and replies are loaded."""
    yield from selected_channels
def _epoch_microseconds(timestamp: ir.TimestampColumn) -> ir.StringColumn:
    """Render a timestamp as a `seconds.microseconds` epoch string.

    The whole seconds since the epoch are added to the microsecond component
    expressed as a fraction of a second, and the sum is cast to a string.
    Required to meet the Slack format.
    """
    micros_per_second = 1_000_000
    whole_seconds = timestamp.epoch_seconds()
    fraction_of_second = timestamp.microsecond() / micros_per_second
    epoch_value = whole_seconds + fraction_of_second
    return epoch_value.cast(str)
def channel_message(
    channel: str,
    db_con: ibis.BaseBackend,
    pipeline: dlt.Pipeline,
) -> ir.Table:
    """Load the parent messages of `channel` with string timestamps.

    The table `<channel>_message` is read from the pipeline's dataset. Both
    `thread_ts` and `ts` are converted to epoch strings. `thread_ts` is only
    set for messages that have replies / started a thread; the remaining null
    values are filled with the message's own `ts` via coalesce.

    Slack reference: https://api.slack.com/messaging/retrieving#finding_threads
    """
    raw_messages = db_con.table(
        f"{channel}_message",
        schema=pipeline.dataset_name,
        database=pipeline.pipeline_name,
    )
    with_string_timestamps = raw_messages.mutate(
        thread_ts=_epoch_microseconds(ibis._.thread_ts).cast(str),
        ts=_epoch_microseconds(ibis._.ts).cast(str),
    )
    # Done in a second mutate so the coalesce sees the string columns above.
    filled_thread_ts = ibis.coalesce(ibis._.thread_ts, ibis._.ts)
    return with_string_timestamps.mutate(thread_ts=filled_thread_ts)
def channel_replies(
    channel: str,
    db_con: ibis.BaseBackend,
    pipeline: dlt.Pipeline,
) -> ir.Table:
    """Load the replies table (`<channel>_replies_message`) of `channel`."""
    replies_table_name = f"{channel}_replies_message"
    return db_con.table(
        replies_table_name,
        schema=pipeline.dataset_name,
        database=pipeline.pipeline_name,
    )
def channel_threads(
    channel_message: ir.Table,
    channel_replies: ir.Table,
) -> ir.Table:
    """Combine parent messages and replies into one table.

    Both inputs are reduced to the same set of columns before the union. The
    result is ordered by thread start (`thread_ts`), then by message
    timestamp (`ts`).
    """
    shared_columns = ["channel", "thread_ts", "ts", "user", "text", "_dlt_load_id", "_dlt_id"]
    parents = channel_message.select(shared_columns)
    replies = channel_replies.select(shared_columns)
    combined = ibis.union(parents, replies)
    return combined.order_by([ibis._.thread_ts, ibis._.ts])
def channels_collection(channel_threads: Collect[ir.Table]) -> ir.Table:
    """Union the collected `channel_threads` tables of all channels into one."""
    per_channel_tables = list(channel_threads)
    return ibis.union(*per_channel_tables)
def _format_messages(threads: ir.Table) -> ir.Table:
    """Add a `message` column: the text prefixed with a per-thread user number.

    Within each `thread_ts` group, users are dense-ranked by the `user`
    column; the rank plus one, as a string, is prepended to the text as
    `"<n>: <text>"`.
    """
    # The `+ 1` presumably shifts a zero-based rank to start at 1 - TODO confirm.
    # NOTE(review): the prefix is "<n>: " whereas `summary_prompt` refers to
    # "User1", and the rank follows the ordering of `user`, not who posted first.
    user_number = ibis.dense_rank().over(order_by="user") + 1
    prefixed_text = user_number.cast(str).concat(": ", ibis._.text)
    per_thread = threads.group_by("thread_ts")
    return per_thread.mutate(message=prefixed_text)
def _aggregate_thread(threads: ir.Table) -> ir.Table:
    """Collapse each thread into one row with its messages joined as a string.

    Functions decorated with `@ibis.udf` are loaded by the Ibis backend.
    They aren't meant to be called directly, which is why their bodies
    only raise.
    ref: https://ibis-project.org/how-to/extending/builtin

    The result has one row per `thread_ts` with the columns `thread`,
    `num_messages`, `users`, `_dlt_load_id` and `_dlt_id`.
    """

    @ibis.udf.agg.builtin(name="string_agg")
    def _string_agg(arg, sep: str = "\n ") -> str:
        raise NotImplementedError

    @ibis.udf.agg.builtin(name="array_agg")
    def _array_agg(arg) -> list[str]:
        raise NotImplementedError

    per_thread = threads.group_by("thread_ts")
    return per_thread.agg(
        # All messages of the thread concatenated into a single string.
        thread=_string_agg(ibis._.message),
        num_messages=ibis._.count(),
        # Distinct users that posted in the thread.
        users=_array_agg(ibis._.user).unique(),
        # Largest load id among the thread's messages.
        _dlt_load_id=ibis._._dlt_load_id.max(),
        _dlt_id=_array_agg(ibis._._dlt_id),
    )
def summary_prompt() -> str:
    """Return the LLM prompt template used to summarize a Slack thread.

    The template contains a single `{text}` placeholder that is meant to be
    filled with the thread via `str.format`.
    """
    # The literal starts with a line continuation so that every line, the
    # first one included, carries the same indentation. `textwrap.dedent`
    # only strips whitespace common to all lines; with text directly after
    # the opening quotes it would strip nothing.
    return textwrap.dedent(
        """\
        Hamilton is an open source library to write dataflows in Python. It is used by developers for data engineering, data science, machine learning, and LLM workflows.
        Next is a discussion thread about Hamilton started by User1. Complete these tasks: identify the issue raised by User1, summarize the discussion, indicate if you think the issue was resolved.
        DISCUSSION THREAD
        {text}
        """
    )
def _summary(threads: ir.Table, prompt: str) -> ir.Table:
    """Add a `summary` column holding an LLM summary of each thread.

    Uses a scalar Python UDF executed by the backend; it is applied to the
    `thread` column together with `prompt`.
    """

    @ibis.udf.scalar.python
    def _openai_completion_udf(
        text: str, prompt_template: str
    ) -> str:  # Ibis requires `str` type hint even if None is allowed
        """Fill `prompt_template` with `text` and use OpenAI chat completion.

        Returns None if:
        - `text` is empty or None
        - `content` is too long
        - OpenAI call fails or the response has no usable content
        """
        if not text:
            return None

        content = prompt_template.format(text=text)
        # Size guard; presumably approximates tokens as ~4 characters each
        # against an 8191-token budget - TODO confirm against the model limit.
        if len(content) // 4 > 8191:
            return None

        # Built outside the `try` on purpose: a failure here (for example a
        # missing API key) keeps propagating as before instead of silently
        # yielding None for every row.
        client = openai.OpenAI()
        try:
            # The request itself must be inside the `try`; otherwise a single
            # failed API call aborts the whole query instead of yielding None.
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": content}],
            )
            return response.choices[0].message.content
        except (openai.OpenAIError, IndexError, AttributeError, TypeError):
            # Best-effort: API errors and malformed responses give no summary.
            return None

    return threads.mutate(summary=_openai_completion_udf(threads.thread, prompt))
# @pipe operator facilitates managing the function/node namespace
@pipe(
    step(_format_messages),
    step(_aggregate_thread),
    step(_summary, prompt=source("summary_prompt")),
)
def threads(channels_collection: ir.Table) -> ir.Table:
    """Build the `threads` table from the messages of all channels.

    The body returns its input unchanged. The `@pipe` steps listed above
    format the messages, aggregate them per thread and generate the
    summaries, with the prompt taken from the `summary_prompt` node.
    """
    return channels_collection
def insert_threads(threads: ir.Table) -> int:
    """Save `threads` as the table "threads" and return its row count.

    The backend is the one returned by `ibis.get_backend()`. As before, the
    table is created without overwrite.
    """
    db_con = ibis.get_backend()
    # `create_table` given an expression already populates the new table.
    # A further `insert` of the same expression would store every row twice
    # and evaluate the expression a second time, so none is issued.
    threads_table = db_con.create_table("threads", threads)
    return int(threads_table.count().execute())