sample_cudapoa.cpp — 335 lines (298 loc) · 13.3 KB
/*
* Copyright 2019-2020 NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <file_location.hpp>

#include <claraparabricks/genomeworks/cudapoa/batch.hpp>
#include <claraparabricks/genomeworks/cudapoa/cudapoa.hpp>
#include <claraparabricks/genomeworks/cudapoa/utils.hpp> // for get_multi_batch_sizes()
#include <claraparabricks/genomeworks/utils/cudautils.hpp>
#include <claraparabricks/genomeworks/utils/signed_integer_utils.hpp>

#include <cuda_runtime_api.h>

#include <getopt.h>
#include <unistd.h>

#include <cassert>   // assert() on the device count
#include <cstdint>   // int32_t / int64_t / uint16_t
#include <iostream>  // std::cout / std::cerr
#include <memory>    // std::unique_ptr
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
using namespace claraparabricks::genomeworks;
using namespace claraparabricks::genomeworks::cudapoa;
/// @brief Create a CUDAPOA Batch object on the first GPU, sized to use 90% of
///        the currently free device memory.
///
/// @param msa        If true the batch produces MSA output, otherwise consensus.
/// @param batch_size Band/size configuration used to dimension the batch.
/// @return Owning pointer to the newly created Batch.
std::unique_ptr<Batch> initialize_batch(bool msa, const BatchConfig& batch_size)
{
    // Get device information. The sample requires at least one CUDA device.
    int32_t device_count = 0;
    GW_CU_CHECK_ERR(cudaGetDeviceCount(&device_count));
    assert(device_count > 0);

    // FIX: cudaSetDevice/cudaMemGetInfo were previously unchecked; a failure
    // would have left `free` unspecified and produced a bogus memory budget.
    size_t total = 0, free = 0;
    GW_CU_CHECK_ERR(cudaSetDevice(0)); // Using first GPU for sample.
    GW_CU_CHECK_ERR(cudaMemGetInfo(&free, &total));

    // Initialize internal logging framework.
    Init();

    // Initialize CUDAPOA batch object for batched processing of POAs on the GPU.
    const int32_t device_id = 0;
    cudaStream_t stream     = 0; // default stream
    int64_t mem_per_batch   = 0.9 * free; // Using 90% of GPU available memory for CUDAPOA batch.
    const int32_t mismatch_score = -6, gap_score = -8, match_score = 8;

    std::unique_ptr<Batch> batch = create_batch(device_id,
                                                stream,
                                                mem_per_batch,
                                                msa ? OutputType::msa : OutputType::consensus,
                                                batch_size,
                                                gap_score,
                                                mismatch_score,
                                                match_score);
    return batch;
}
/// @brief Run POA generation on a batch and emit the results.
///
/// @param batch             Batch whose POA groups have already been added.
/// @param msa_flag          If true, fetch MSA output; otherwise fetch consensus.
/// @param print             If true, write per-group results to stdout.
/// @param list_of_group_ids Global group ids for the groups in this batch.
/// @param id_offset         Offset into list_of_group_ids of this batch's first group.
void process_batch(Batch* batch, bool msa_flag, bool print, std::vector<int32_t>& list_of_group_ids, int id_offset)
{
    batch->generate_poa();

    std::string error_message, error_hint;

    // Decode a failure status and report it on stderr with the given header line.
    auto report_failure = [&](StatusType failure, const std::string& header) {
        decode_error(failure, error_message, error_hint);
        std::cerr << header << std::endl;
        std::cerr << error_message << std::endl
                  << error_hint << std::endl;
    };

    if (msa_flag)
    {
        std::vector<std::vector<std::string>> msa; // MSA per group
        std::vector<StatusType> output_status;     // Status of MSA generation per group

        const StatusType batch_status = batch->get_msa(msa, output_status);
        if (batch_status != StatusType::success)
        {
            report_failure(batch_status, "Could not generate MSA for batch : ");
        }

        for (int32_t g = 0; g < get_size(msa); g++)
        {
            if (output_status[g] != StatusType::success)
            {
                report_failure(output_status[g],
                               "Error generating MSA for POA group " + std::to_string(list_of_group_ids[g + id_offset]));
                continue;
            }
            if (print)
            {
                for (const auto& alignment : msa[g])
                {
                    std::cout << alignment << std::endl;
                }
            }
        }
    }
    else
    {
        std::vector<std::string> consensus;          // Consensus string for each POA group
        std::vector<std::vector<uint16_t>> coverage; // Per base coverage for each consensus
        std::vector<StatusType> output_status;       // Status of consensus generation per group

        const StatusType batch_status = batch->get_consensus(consensus, coverage, output_status);
        if (batch_status != StatusType::success)
        {
            report_failure(batch_status, "Could not generate consensus for batch : ");
        }

        for (int32_t g = 0; g < get_size(consensus); g++)
        {
            if (output_status[g] != StatusType::success)
            {
                report_failure(output_status[g],
                               "Error generating consensus for POA group " + std::to_string(list_of_group_ids[g + id_offset]));
                continue;
            }
            if (print)
            {
                std::cout << consensus[g] << std::endl;
            }
        }
    }
}
int main(int argc, char** argv)
{
// Process options
int c = 0;
bool msa = false;
bool long_read = false;
BandMode band_mode = BandMode::adaptive_band; // 0: full, 1: static-band, 2: adaptive-band, 3- static-band-traceback 4- adaptive-band-traceback
bool help = false;
bool print = false;
bool print_graph = false;
int32_t band_width = 256; // default band-width for static bands, and min band-width in adaptive bands
while ((c = getopt(argc, argv, "mlb:pgh")) != -1)
{
switch (c)
{
case 'm':
msa = true;
break;
case 'l':
long_read = true;
break;
case 'b':
if (std::stoi(optarg) < 0 || std::stoi(optarg) > 4)
{
throw std::runtime_error("band-mode must be either 0 for full bands, 1 for static bands, 2 for adaptive bands, 3 and 4 for static and adaptive bands with traceback");
}
band_mode = static_cast<BandMode>(std::stoi(optarg));
break;
case 'p':
print = true;
break;
case 'g':
print_graph = true;
break;
case 'h':
help = true;
break;
}
}
if (help)
{
std::cout << "CUDAPOA API sample program. Runs consensus or MSA generation on pre-canned data." << std::endl;
std::cout << "Usage:" << std::endl;
std::cout << "./sample_cudapoa [-m] [-h]" << std::endl;
std::cout << "-m : Generate MSA (if not provided, generates consensus by default)" << std::endl;
std::cout << "-l : Perform long-read sample (if not provided, will run short-read sample by default)" << std::endl;
std::cout << "-b : Sets band mode 0: full-alignment, 1: static band, 2: adaptive band , will run adaptive band by default)" << std::endl;
std::cout << "-p : Print the MSA or consensus output to stdout" << std::endl;
std::cout << "-g : Print POA graph in dot format, this option is only for long-read sample" << std::endl;
std::cout << "-h : Print help message" << std::endl;
std::exit(0);
}
// Load input data. Each window is represented as a vector of strings. The sample
// data has many such windows to process, hence the data is loaded into a vector
// of vector of strings.
std::vector<std::vector<std::string>> windows;
if (long_read)
{
const std::string input_file = std::string(CUDAPOA_BENCHMARK_DATA_DIR) + "/sample-bonito.txt";
parse_cudapoa_file(windows, input_file, -1);
}
else
{
const std::string input_file = std::string(CUDAPOA_BENCHMARK_DATA_DIR) + "/sample-windows.txt";
parse_cudapoa_file(windows, input_file, 1000);
}
// Create a vector of POA groups based on windows
std::vector<Group> poa_groups(windows.size());
for (int32_t i = 0; i < get_size(windows); ++i)
{
Group& group = poa_groups[i];
// Create a new entry for each sequence and add to the group.
for (const auto& seq : windows[i])
{
Entry poa_entry{};
poa_entry.seq = seq.c_str();
poa_entry.length = seq.length();
poa_entry.weights = nullptr;
group.push_back(poa_entry);
}
}
// for error code message
std::string error_message, error_hint;
// analyze the POA groups and create a minimal set of batches to process them all
std::vector<BatchConfig> list_of_batch_sizes;
std::vector<std::vector<int32_t>> list_of_groups_per_batch;
get_multi_batch_sizes(list_of_batch_sizes, list_of_groups_per_batch, poa_groups, msa, band_width, band_mode);
int32_t group_count_offset = 0;
for (int32_t b = 0; b < get_size(list_of_batch_sizes); b++)
{
auto& batch_size = list_of_batch_sizes[b];
auto& batch_group_ids = list_of_groups_per_batch[b];
// Initialize batch.
std::unique_ptr<Batch> batch = initialize_batch(msa, batch_size);
// Loop over all the POA groups for the current batch, add them to the batch and process them.
int32_t group_count = 0;
for (int32_t i = 0; i < get_size(batch_group_ids);)
{
Group& group = poa_groups[batch_group_ids[i]];
std::vector<StatusType> seq_status;
StatusType status = batch->add_poa_group(seq_status, group);
// NOTE: If number of batch groups smaller than batch capacity, then run POA generation
// once last POA group is added to batch.
if (status == StatusType::exceeded_maximum_poas || (i == get_size(batch_group_ids) - 1))
{
// at least one POA should have been added before processing the batch
if (batch->get_total_poas() > 0)
{
// No more POA groups can be added to batch. Now process batch.
process_batch(batch.get(), msa, print, batch_group_ids, group_count);
if (print_graph && long_read)
{
std::vector<DirectedGraph> graph;
std::vector<StatusType> graph_status;
batch->get_graphs(graph, graph_status);
for (auto& g : graph)
{
std::cout << g.serialize_to_dot() << std::endl;
}
}
// After MSA/consensus is generated for batch, reset batch to make room for next set of POA groups.
batch->reset();
// In case that number of batch groups is more than the capacity available on GPU, the for loop breaks into smaller number of groups.
// if adding group i in batch->add_poa_group is not successful, it wont be processed in this iteration, therefore we print i-1
// to account for the fact that group i was excluded at this round.
if (status == StatusType::success)
{
std::cout << "Processed groups " << group_count + group_count_offset << " - " << i + group_count_offset << " (batch " << b << ")" << std::endl;
}
else
{
std::cout << "Processed groups " << group_count + group_count_offset << " - " << i - 1 + group_count_offset << " (batch " << b << ")" << std::endl;
}
}
else
{
// the POA was too large to be added to the GPU, skip and move on
std::cout << "Could not add POA group " << batch_group_ids[i] << " to batch " << b << std::endl;
i++;
}
group_count = i;
}
if (status == StatusType::success)
{
// Check if all sequences in POA group wre added successfully.
int32_t num_dropped_seq = 0;
for (const auto& s : seq_status)
{
if (s == StatusType::exceeded_maximum_sequence_size)
{
num_dropped_seq++;
}
}
if (num_dropped_seq > 0)
{
std::cerr << "Dropping " << num_dropped_seq << " sequence(s) in POA group " << batch_group_ids[i] << " because it exceeded maximum size" << std::endl;
}
i++;
}
if (status != StatusType::exceeded_maximum_poas && status != StatusType::success)
{
decode_error(status, error_message, error_hint);
std::cerr << "Could not add POA group " << batch_group_ids[i] << " to batch " << b << std::endl;
std::cerr << error_message << std::endl
<< error_hint << std::endl;
i++;
}
}
group_count_offset += get_size(batch_group_ids);
}
return 0;
}