// QueryCalcImpl.java (426 lines, 325 loc, 13.3 KB) - forked from avavilau/query-test-task
package org.query.calc;
import java.io.*;
import java.nio.file.Path;
import java.util.Collection;
import java.util.Comparator;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.BiFunction;
public class QueryCalcImpl implements QueryCalc {
/** Maximum number of rows kept in the query result (the {@code LIMIT 10} clause of the query). */
static final int LIMIT = 10;
/**
 * Ordering of result rows: larger {@code value} first; rows whose values tie are ordered by
 * ascending {@code number}. Values are compared with {@code compareDouble}, so 0.0 and -0.0 tie.
 */
static final Comparator<Row> ORDER = (first, second) -> {
    final int byValue = compareDouble(first.value, second.value);
    if (byValue != 0) {
        return -byValue; // negate to sort by value in descending order
    }
    return Integer.compare(first.number, second.number);
};
/**
 * Evaluates the query below over the three input tables and writes the result to {@code output}.
 * <p>
 * Task statement:
 * <ul>
 * <li>{@code t1} is a file containing table "t1" with two columns "a" and "x". The first line is the
 * number of rows; each following line contains exactly one row made of two numbers parsable by
 * {@code Double.parseDouble()}: the values of columns a and x respectively. See test resources
 * for examples.</li>
 * <li>{@code t2} is a file containing table "t2" with columns "b" and "y". Same format.</li>
 * <li>{@code t3} is a file containing table "t3" with columns "c" and "z". Same format.</li>
 * <li>{@code output} is a table stored in the same format: the first line is the number of rows,
 * then each line is one row containing two numbers: the values of columns a and s.</li>
 * </ul>
 * The number of rows of each of the three tables lies in the range [0, 1_000_000].
 * It is guaranteed that the full content of all three tables fits into RAM.
 * It is guaranteed that the full outer join of at least one pair of tables (t1xt2 or t2xt3 or
 * t1xt3) fits into RAM.
 * <p>
 * The query to implement, with reasonable effort put into computation time, memory usage and
 * resource utilization (in that exact order of priority); any library from Maven Central may be used:
 * <pre>{@code
 * SELECT a, SUM(x * y * z) AS s FROM
 * t1 LEFT JOIN (SELECT * FROM t2 JOIN t3) AS t
 * ON a < b + c
 * GROUP BY a
 * STABLE ORDER BY s DESC
 * LIMIT 10;
 * }</pre>
 * Note: STABLE is not a standard SQL command. It means that the original order must be preserved:
 * in case of a tie on the s-value, the value of a with the lower row number is preferred. When a
 * value of a occurs multiple times, the group is assumed to have the row number of its first
 * occurrence.
 *
 * @param t1     file holding table t1 (columns a, x)
 * @param t2     file holding table t2 (columns b, y)
 * @param t3     file holding table t3 (columns c, z)
 * @param output file the result table is written to
 * @throws IOException if reading an input file or writing the output file fails
 */
@Override
public void select(Path t1, Path t2, Path t3, Path output) throws IOException {
    final Table left = read(t1);
    final Table middle = read(t2);
    final Table right = read(t3);
    // The planner picks the join strategy from the table sizes.
    final QueryExecutor executor = new QueryPlanner().plan(left, middle, right);
    write(output, executor.execute(left, middle, right));
}
/**
 * Immutable triple of a row number, a key and a value.
 */
static final class Row {
    private final int number;
    private final double key;
    private final double value;

    /**
     * @param number row number
     * @param key    key column value
     * @param value  value column value
     */
    public Row(final int number, final double key, final double value) {
        this.number = number;
        this.key = key;
        this.value = value;
    }

    /** @return the row number */
    public int getNumber() {
        return number;
    }

    /** @return the key */
    public double getKey() {
        return key;
    }

    /** @return the value */
    public double getValue() {
        return value;
    }

    /** Renders the row as {@code #<number>: <key> -> <value>}. */
    @Override
    public String toString() {
        return new StringBuilder()
                .append('#').append(number)
                .append(": ").append(key)
                .append(" -> ").append(value)
                .toString();
    }
}
/**
 * Column-oriented table: three parallel arrays plus the row count.
 * The arrays are stored by reference; no copies are made.
 */
static final class Table {
    /** Number of rows. */
    final int size;
    /** Row number of each row. */
    final int[] numbers;
    /** Key of each row. */
    final double[] keys;
    /** Value of each row. */
    final double[] values;

    public Table(
            final int size,
            final int[] numbers,
            final double[] keys,
            final double[] values) {
        this.size = size;
        this.numbers = numbers;
        this.keys = keys;
        this.values = values;
    }
}
/** Picks the executor used to evaluate the query, based on the sizes of the three tables. */
static final class QueryPlanner {
    /**
     * The planner is based on algorithmic complexity only and does not take mechanical sympathy
     * into account - the memory hierarchy (CPU caches) and the memory access pattern.
     * <p>
     * A - min size.
     * B - mid size.
     * C - max size.
     * <p>
     * Time: A * B * log(C) + C * log(C).
     * Space: C.
     *
     * @return an {@code EmptyQueryExecutor} when {@code a} has no rows, a
     *         {@code LeftTableJoinQueryExecutor} when {@code a} is strictly larger than both
     *         other tables, otherwise a {@code RightTableJoinQueryExecutor}
     */
    QueryExecutor plan(final Table a,
                       final Table b,
                       final Table c) {
        if (a.size == 0) {
            return new EmptyQueryExecutor();
        }
        final boolean leftIsStrictlyLargest = a.size > b.size && a.size > c.size;
        return leftIsStrictlyLargest
                ? new LeftTableJoinQueryExecutor()
                : new RightTableJoinQueryExecutor();
    }
}
/** Strategy that computes the query result from the three loaded tables. */
interface QueryExecutor {
    /**
     * @param a table read from t1
     * @param b table read from t2
     * @param c table read from t3
     * @return the result table
     */
    Table execute(Table a, Table b, Table c);
}
/** Executor that ignores its inputs and always yields a table without rows. */
static final class EmptyQueryExecutor implements QueryExecutor {
    /** Shared zero-row result; the same instance is returned on every call. */
    private static final Table NO_ROWS = new Table(0, new int[0], new double[0], new double[0]);

    @Override
    public Table execute(final Table a, final Table b, final Table c) {
        return NO_ROWS;
    }
}
/**
 * Executor that walks every row of {@code a} and, for each of them, scans the smaller of the two
 * other tables while binary-searching the larger one.
 * Relies on table keys being in ascending order (as produced by {@code read}).
 */
static final class RightTableJoinQueryExecutor implements QueryExecutor {

    @Override
    public Table execute(final Table a, Table b, Table c) {
        // Scan the smaller table linearly, binary-search the larger one.
        final boolean bIsLarger = b.size > c.size;
        final Table scanned = bIsLarger ? c : b;
        final Table searched = bIsLarger ? b : c;

        final double[] tailSums = suffix(searched);
        final TreeSet<Row> best = new TreeSet<>(ORDER);
        for (int i = 0; i < a.size; i++) {
            final double joined = join(a.keys[i], scanned, searched, tailSums);
            best.add(new Row(a.numbers[i], a.keys[i], a.values[i] * joined));
            // Keep at most LIMIT rows: drop the one that sorts last.
            if (best.size() > LIMIT) {
                best.pollLast();
            }
        }
        return toTable(best);
    }

    /**
     * Builds suffix sums of the table's values: element {@code i} of the result is
     * {@code values[i] + values[i + 1] + ... + values[size - 1]}, accumulated from the end.
     */
    private static double[] suffix(final Table table) {
        final double[] values = table.values;
        final double[] sums = new double[table.size];
        double running = 0.0;
        int i = sums.length;
        while (i-- > 0) {
            running += values[i];
            sums[i] = running;
        }
        return sums;
    }

    /**
     * For one key of {@code a}, sums {@code bValue * cSums[k]} over the rows of {@code b}, where
     * {@code k} is the first index of {@code c} with {@code aKey < bKey + cKey}.
     * Rows of {@code b} with no such index contribute nothing.
     */
    private static double join(final double aKey, final Table b, final Table c, final double[] cSums) {
        final double[] cKeys = c.keys;
        final int cCount = cKeys.length;
        double total = 0.0;
        // As bKey grows the first matching index in c can only move left,
        // so the previous hit is reused as the upper search bound.
        int upper = cCount - 1;
        for (int i = 0; i < b.keys.length; i++) {
            final int first = search(aKey, b.keys[i], cKeys, upper);
            if (first >= cCount) {
                continue; // no row of c satisfies the join condition for this b row
            }
            total += b.values[i] * cSums[first];
            upper = first;
        }
        return total;
    }

    /**
     * Binary search over {@code cKeys[0..r]} for the first index whose key satisfies
     * {@code aKey < bKey + cKey}; returns {@code r + 1} when no index in that range does.
     */
    private static int search(final double aKey, final double bKey, final double[] cKeys, int r) {
        int lo = 0;
        int hi = r;
        while (lo <= hi) {
            final int mid = (lo + hi) >>> 1;
            if (aKey < bKey + cKeys[mid]) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }
}
/**
 * Executor that enumerates every (b, c) pair and binary-searches table {@code a} for each pair.
 * <p>
 * For each pair the product {@code bValue * cValue} is credited to the last row of {@code a} whose
 * key is smaller than {@code bKey + cKey}; a suffix sum over those credits then yields, for every
 * row of {@code a}, the total of all pairs with {@code aKey < bKey + cKey}.
 * Relies on table keys being in ascending order (as produced by {@code read}).
 */
static final class LeftTableJoinQueryExecutor implements QueryExecutor {

    /**
     * @return the (at most {@code LIMIT}) best rows according to {@code ORDER}, each holding the
     *         row number and key from {@code a} and the value {@code aValue * joinedSum}
     */
    @Override
    public Table execute(final Table a, final Table b, final Table c) {
        final TreeSet<Row> rows = new TreeSet<>(ORDER);
        final double[] suffix = suffix(a, b, c);
        for (int i = 0; i < a.size; i++) {
            final int number = a.numbers[i];
            final double key = a.keys[i];
            final double value = a.values[i];
            final double sum = value * suffix[i];
            final Row row = new Row(number, key, sum);
            rows.add(row);
            // Keep at most LIMIT rows: drop the one that sorts last.
            if (rows.size() > LIMIT) {
                rows.pollLast();
            }
        }
        return toTable(rows);
    }

    /**
     * Computes, for every index {@code i} of {@code a}, the sum of {@code bValue * cValue} over
     * all (b, c) pairs that were credited to an index {@code >= i}.
     */
    private static double[] suffix(final Table a, final Table b, final Table c) {
        final double[] suffix = new double[a.size];
        final double[] bKeys = b.keys;
        final double[] bValues = b.values;
        final int bSize = bKeys.length;
        for (int i = 0; i < bSize; i++) {
            final double bKey = bKeys[i];
            final double bValue = bValues[i];
            join(bKey, bValue, a, c, suffix);
        }
        // Turn per-index credits into suffix sums, accumulating from the end.
        for (int i = suffix.length - 2; i >= 0; i--) {
            suffix[i] += suffix[i + 1];
        }
        return suffix;
    }

    /**
     * For one row of {@code b}, adds {@code bValue * cValue} of every row of {@code c} to
     * {@code aSums[k]}, where {@code k} is the last index of {@code a} with
     * {@code aKey < bKey + cKey}. Pairs for which no such index exists are skipped.
     */
    private static void join(final double bKey, final double bValue, final Table a, final Table c, final double[] aSums) {
        final double[] aKeys = a.keys;
        final double[] cKeys = c.keys;
        final double[] cValues = c.values;
        final int cSize = cKeys.length;
        for (int i = 0, j = 0; i < cSize; i++) {
            final double cKey = cKeys[i]; // key goes up
            final int k = search(bKey, cKey, aKeys, j); // key goes up
            if (k >= 0) {
                final double cValue = cValues[i];
                aSums[k] += bValue * cValue;
                // aKeys[k] matched this pair, so it also matches every later (larger) cKey:
                // the next search can start at k instead of 0.
                j = k;
            }
        }
    }

    /**
     * Binary search over {@code aKeys[l..length - 1]} for the last index whose key satisfies
     * {@code aKey < bKey + cKey}; returns {@code l - 1} when no index in that range does.
     */
    private static int search(final double bKey, final double cKey, double[] aKeys, int l) {
        int r = aKeys.length - 1;
        while (l <= r) {
            final int m = (l + r) >>> 1;
            final double aKey = aKeys[m];
            if (aKey < bKey + cKey) {
                l = m + 1;
            } else {
                r = m - 1;
            }
        }
        return r;
    }
}
/**
 * Loads a table file: the first line is the row count, each following line is one row holding a
 * key and a value separated by a space.
 * <p>
 * Rows with the same key (0.0 and -0.0 count as the same key) are merged: their values are summed
 * and the smallest row number is kept. The returned table is ordered by ascending key.
 *
 * @param file the table file to read
 * @return the de-duplicated table, sorted by key
 * @throws EOFException             if the file has no header line or fewer rows than announced
 * @throws IOException              if the file cannot be read
 * @throws NumberFormatException    if the header or a number in a row cannot be parsed
 * @throws IllegalArgumentException if a row has no separator, or a key or value is not finite
 */
private static Table read(final Path file) throws IOException {
    try (final BufferedReader stream = new BufferedReader(new FileReader(file.toFile()))) {
        final String header = stream.readLine();
        if (header == null) {
            throw new EOFException("Missing row count in " + file);
        }
        final int size = Integer.parseUnsignedInt(header.trim());
        final Comparator<Double> sorter = QueryCalcImpl::compareDouble;
        final BiFunction<Row, Row, Row> merger = (left, right) -> {
            final int number = Math.min(left.getNumber(), right.getNumber());
            final double key = left.key;
            final double sum = Double.sum(left.value, right.value);
            return new Row(number, key, sum);
        };
        final TreeMap<Double, Row> uniques = new TreeMap<>(sorter);
        for (int i = 0; i < size; i++) {
            final String raw = stream.readLine();
            if (raw == null) {
                throw new EOFException("Expected " + size + " rows but found only " + i + " in " + file);
            }
            // Trim so that trailing whitespace does not become the "last separator".
            final String line = raw.trim();
            final int index = line.lastIndexOf(' ');
            if (index < 0) {
                throw new IllegalArgumentException("Malformed row " + i + " in " + file + ": " + raw);
            }
            final double key = Double.parseDouble(line.substring(0, index));
            final double value = Double.parseDouble(line.substring(index + 1));
            final Row row = new Row(i, key, value);
            verify(row);
            uniques.merge(key, row, merger);
        }
        // Flatten the sorted, de-duplicated rows into parallel arrays.
        final int newSize = uniques.size();
        final int[] numbers = new int[newSize];
        final double[] keys = new double[newSize];
        final double[] values = new double[newSize];
        int index = 0;
        for (final Row row : uniques.values()) {
            numbers[index] = row.number;
            keys[index] = row.key;
            values[index] = row.value;
            index++;
        }
        return new Table(newSize, numbers, keys, values);
    }
}
/**
 * Stores a table in the text format used by the inputs: the row count on the first line, then
 * one {@code key value} pair per line.
 *
 * @param file  destination file
 * @param table table to store
 * @throws IOException if the file cannot be written
 */
private static void write(final Path file, final Table table) throws IOException {
    try (final BufferedWriter out = new BufferedWriter(new FileWriter(file.toFile()))) {
        out.write(Integer.toString(table.size));
        out.newLine();
        int row = 0;
        while (row < table.size) {
            out.write(table.keys[row] + " " + table.values[row]);
            out.newLine();
            row++;
        }
    }
}
/**
 * Rejects rows whose key or value is NaN or infinite. The key is checked first.
 *
 * @param row the row to check
 * @throws IllegalArgumentException if the key or the value is not a finite number
 */
private static void verify(final Row row) {
    final double key = row.getKey();
    if (Double.isNaN(key) || Double.isInfinite(key)) {
        throw new IllegalArgumentException("Not supported key: " + key);
    }
    final double value = row.getValue();
    if (Double.isNaN(value) || Double.isInfinite(value)) {
        throw new IllegalArgumentException("Not supported value: " + value);
    }
}
/**
 * Copies the rows, in the collection's iteration order, into a column-oriented {@code Table}.
 *
 * @param rows rows to copy
 * @return a table with one entry per row
 */
private static Table toTable(final Collection<? extends Row> rows) {
    final int count = rows.size();
    final int[] numbers = new int[count];
    final double[] keys = new double[count];
    final double[] values = new double[count];
    int next = 0;
    for (final Row current : rows) {
        numbers[next] = current.getNumber();
        keys[next] = current.getKey();
        values[next] = current.getValue();
        next++;
    }
    return new Table(count, numbers, keys, values);
}
/**
 * Compares two doubles like {@code Double.compare}, except that operands equal under {@code ==}
 * (notably 0.0 and -0.0) compare as equal.
 *
 * @return 0 if {@code left == right}, otherwise the result of {@code Double.compare(left, right)}
 */
static int compareDouble(final double left, final double right) {
    // Double.compare alone would order -0.0 before 0.0.
    return (left == right) ? 0 : Double.compare(left, right);
}
}