Skip to content

Commit

Permalink
Add file name callback for stream-gears download.
Browse files Browse the repository at this point in the history
  • Loading branch information
CoolZxp committed Apr 25, 2024
1 parent 8688dbe commit 6d6a4db
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 42 deletions.
13 changes: 9 additions & 4 deletions crates/biliup/src/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::collections::HashMap;
use tracing::{debug, error, info};

use crate::downloader::util::Segmentable;
use crate::downloader::util::{LifecycleFile, Segmentable};

use crate::client::StatelessClient;
use std::str::FromStr;
use crate::downloader::extractor::CallbackFn;

pub mod error;
pub mod extractor;
Expand All @@ -24,6 +25,7 @@ pub async fn download(
headers: HeaderMap,
file_name: &str,
segment: Segmentable,
file_name_hook: Option<CallbackFn>,
) -> anyhow::Result<()> {
let client = StatelessClient::new(headers);
let response = client.retryable(url).await?;
Expand All @@ -40,14 +42,16 @@ pub async fn download(
Ok((_i, header)) => {
debug!("header: {header:#?}");
info!("Downloading {}...", url);
httpflv::download(connection, file_name, segment).await;
let file = LifecycleFile::new(file_name, "flv", file_name_hook);
httpflv::download(connection, file, segment).await;
}
Err(Err::Incomplete(needed)) => {
error!("needed: {needed:?}")
}
Err(e) => {
error!("{e}");
hls::download(url, &client, file_name, segment).await?;
let file = LifecycleFile::new(file_name, "ts", file_name_hook);
hls::download(url, &client, file, segment).await?;
}
}
Ok(())
Expand Down Expand Up @@ -86,7 +90,7 @@ pub fn construct_headers(hash_map: HashMap<String, String>) -> HeaderMap {
#[cfg(test)]
mod tests {
use crate::downloader::download;
use crate::downloader::util::Segmentable;
use crate::downloader::util::{Segmentable};
use anyhow::Result;
use reqwest::header::{HeaderMap, HeaderValue, REFERER};

Expand All @@ -106,6 +110,7 @@ mod tests {
"testdouyu%Y-%m-%dT%H_%M_%S",
// Segment::Size(20 * 1024 * 1024, 0),
Segmentable::new(Some(std::time::Duration::from_secs(6000)), None),
None,
)?;
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion crates/biliup/src/downloader/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl Site {
}
Extension::Ts => {
let file = LifecycleFile::new(&fmt_file_name, "ts", hook);
hls::download(&self.direct_url, &self.client, &file.fmt_file_name, segment).await?
hls::download(&self.direct_url, &self.client, file, segment).await?
}
}
Ok(())
Expand Down
60 changes: 37 additions & 23 deletions crates/biliup/src/downloader/hls.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
use crate::downloader::error::Result;
use crate::downloader::util::{format_filename, Segmentable};
use crate::downloader::util::{LifecycleFile, Segmentable};
use m3u8_rs::Playlist;

use std::fs::File;
use std::io::{BufWriter, Write};
use std::time::Duration;
use tracing::{debug, error, info, warn};
use tracing::{debug, info, warn};
use url::Url;

use crate::client::StatelessClient;

pub async fn download(
url: &str,
client: &StatelessClient,
file_name: &str,
file: LifecycleFile,
mut splitting: Segmentable,
) -> Result<()> {
info!("Downloading {}...", url);
let resp = client.retryable(url).await?;
info!("{}", resp.status());
// let mut resp = resp.bytes_stream();
let bytes = resp.bytes().await?;
let mut ts_file = TsFile::new(file_name);
let mut ts_file = TsFile::new(file)?;

let mut media_url = Url::parse(url)?;
let mut pl = match m3u8_rs::parse_playlist(&bytes) {
Expand Down Expand Up @@ -62,7 +62,7 @@ pub async fn download(
debug!("Yield segment");
if segment.discontinuity {
warn!("#EXT-X-DISCONTINUITY");
ts_file = TsFile::new(file_name);
ts_file.create_new()?;
// splitting = Segment::from_seg(splitting);
splitting.reset();
}
Expand All @@ -71,12 +71,11 @@ pub async fn download(
client,
&mut ts_file.buf_writer,
)
.await?;
.await?;
splitting.increase_size(length);
splitting.increase_time(Duration::from_secs(segment.duration as u64));
if splitting.needed() {
ts_file = TsFile::new(file_name);
info!("{} splitting.{splitting:?}", ts_file.name);
ts_file.create_new()?;
splitting.reset();
}
previous_last_segment = seq;
Expand Down Expand Up @@ -110,34 +109,49 @@ async fn download_to_file(url: Url, client: &StatelessClient, out: &mut impl Wri

pub struct TsFile {
pub buf_writer: BufWriter<File>,
pub name: String,
pub file: LifecycleFile,
}

impl TsFile {
pub fn new(file_name: &str) -> Self {
let file_name = format_filename(file_name);
let out = File::create(format!("{file_name}.ts.part")).expect("Unable to create ts file.");
let buf_writer = BufWriter::new(out);
Self {
buf_writer,
name: file_name,
}
pub fn new(mut file: LifecycleFile) -> std::io::Result<Self> {
let path = file.create()?;
Ok(Self {
buf_writer: Self::create(path)?,
file,
})
}

pub fn create_new(&mut self) -> std::io::Result<()> {
self.file.rename();
let path = self.file.create()?;
self.buf_writer = Self::create(path)?;
Ok(())
}

fn create<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<BufWriter<File>> {
let path = path.as_ref();
let out = match File::create(path) {
Ok(o) => o,
Err(e) => {
return Err(std::io::Error::new(
e.kind(),
format!("Unable to create file {}", path.display()),
));
}
};
info!("create file {}", path.display());
Ok(BufWriter::new(out))
}
}

impl Drop for TsFile {
fn drop(&mut self) {
std::fs::rename(
format!("{}.ts.part", self.name),
format!("{}.ts", self.name),
)
.unwrap_or_else(|e| error!("{e}"))
self.file.rename()
}
}

#[cfg(test)]
mod tests {

use anyhow::Result;
use reqwest::Url;

Expand Down
20 changes: 11 additions & 9 deletions crates/biliup/src/downloader/httpflv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ use std::time::Duration;
use tokio::time::timeout;
use tracing::{info, warn};

pub async fn download(connection: Connection, file_name: &str, segment: Segmentable) {
let file: LifecycleFile = LifecycleFile::new(file_name, "flv", None);
pub async fn download(connection: Connection, file: LifecycleFile, segment: Segmentable) {
let file_name = file.file_name.clone();
match parse_flv(connection, file, segment).await {
Ok(_) => {
info!("Done... {file_name}");
info!("Done... {}", file_name);
}
Err(e) => {
warn!("{e}")
Expand Down Expand Up @@ -139,10 +139,10 @@ pub(crate) async fn parse_flv(
match &flv_tag {
FlvTag {
data:
TagDataHeader::Video {
frame_type: FrameType::Key,
..
},
TagDataHeader::Video {
frame_type: FrameType::Key,
..
},
..
} => {
let timestamp = flv_tag.header.timestamp as u64;
Expand Down Expand Up @@ -235,7 +235,10 @@ impl Connection {
}
}

pub async fn read_frame(&mut self, chunk_size: usize) -> crate::downloader::error::Result<Bytes> {
pub async fn read_frame(
&mut self,
chunk_size: usize,
) -> crate::downloader::error::Result<Bytes> {
// let mut buf = [0u8; 8 * 1024];
loop {
if chunk_size <= self.buffer.len() {
Expand Down Expand Up @@ -270,7 +273,6 @@ impl Connection {

#[cfg(test)]
mod tests {

use anyhow::Result;
use bytes::{Buf, BufMut, BytesMut};

Expand Down
2 changes: 1 addition & 1 deletion crates/biliup/src/uploader/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ impl Credential {
.error_for_status()?;
let full = raw.bytes().await?;

let res: ResponseData<ResponseValue> = serde_json::from_slice(&full).map_err(|e| Kind::Custom(format!("error decoding response body, content: {:#?}", String::from_utf8_lossy(&full))))?;
let res: ResponseData<ResponseValue> = serde_json::from_slice(&full).map_err(|_| Kind::Custom(format!("error decoding response body, content: {:#?}", String::from_utf8_lossy(&full))))?;
match res {
ResponseData {
code: 0,
Expand Down
40 changes: 37 additions & 3 deletions crates/stream-gears/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use biliup::downloader::construct_headers;
use biliup::downloader::util::Segmentable;
use tracing_subscriber::layer::SubscriberExt;
use biliup::credential::Credential;
use biliup::downloader::extractor::CallbackFn;

#[derive(FromPyObject)]
pub enum PySegment {
Expand All @@ -34,6 +35,18 @@ fn download(
header_map: HashMap<String, String>,
file_name: &str,
segment: PySegment,
) -> PyResult<()> {
download_2(py, url, header_map, file_name, segment, None)
}

#[pyfunction]
fn download_2(
py: Python<'_>,
url: &str,
header_map: HashMap<String, String>,
file_name: &str,
segment: PySegment,
file_name_callback_fn: Option<PyObject>,
) -> PyResult<()> {
py.allow_threads(|| {
let map = construct_headers(header_map);
Expand All @@ -56,13 +69,26 @@ fn download(
.with_timer(local_time)
.with_writer(non_blocking);

let collector = formatting_layer.with(file_layer);
let segment = match segment {
PySegment::Time { time } => Segmentable::new(Some(Duration::from_secs(time)), None),
PySegment::Size { size } => Segmentable::new(None, Some(size)),
};

let file_name_hook = file_name_callback_fn.map(|callback_fn| -> CallbackFn {
Box::new(move |fmt_file_name| {
Python::with_gil(|py| {
match callback_fn.call1(py, (fmt_file_name, )) {
Ok(_) => {}
Err(_) => { tracing::error!("Unable to invoke the callback function.") }
}
})
})
});


let collector = formatting_layer.with(file_layer);
tracing::subscriber::with_default(collector, || -> PyResult<()> {
match biliup::downloader::download(url, map, file_name, segment) {
match biliup::downloader::download(url, map, file_name, segment, file_name_hook) {
Ok(res) => Ok(res),
Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
"{}, {}",
Expand All @@ -73,6 +99,7 @@ fn download(
})
})
}

#[pyfunction]
fn login_by_cookies(file: String) -> PyResult<bool> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand All @@ -86,6 +113,7 @@ fn login_by_cookies(file: String) -> PyResult<bool> {
))),
}
}

#[pyfunction]
fn send_sms(country_code: u32, phone: u64) -> PyResult<String> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand All @@ -98,6 +126,7 @@ fn send_sms(country_code: u32, phone: u64) -> PyResult<String> {
))),
}
}

#[pyfunction]
fn login_by_sms(code: u32, ret: String) -> PyResult<bool> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand All @@ -108,6 +137,7 @@ fn login_by_sms(code: u32, ret: String) -> PyResult<bool> {
Err(_) => Ok(false),
}
}

#[pyfunction]
fn get_qrcode() -> PyResult<String> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand All @@ -120,6 +150,7 @@ fn get_qrcode() -> PyResult<String> {
))),
}
}

#[pyfunction]
fn login_by_qrcode(ret: String) -> PyResult<String> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand All @@ -132,6 +163,7 @@ fn login_by_qrcode(ret: String) -> PyResult<String> {
err
)))
}

#[pyfunction]
fn login_by_web_cookies(sess_data: String, bili_jct: String) -> PyResult<bool> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand All @@ -144,6 +176,7 @@ fn login_by_web_cookies(sess_data: String, bili_jct: String) -> PyResult<bool> {
))),
}
}

#[pyfunction]
fn login_by_web_qrcode(sess_data: String, dede_user_id: String) -> PyResult<bool> {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand Down Expand Up @@ -228,7 +261,7 @@ fn upload(
.build();

match rt.block_on(uploader::upload(studio_pre)) {
Ok(res) => Ok(()),
Ok(_) => Ok(()),
// Ok(_) => { },
Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
"{}, {}",
Expand All @@ -250,6 +283,7 @@ fn stream_gears(m: &Bound<'_, PyModule>) -> PyResult<()> {
// .init();
m.add_function(wrap_pyfunction!(upload, m)?)?;
m.add_function(wrap_pyfunction!(download, m)?)?;
m.add_function(wrap_pyfunction!(download_2, m)?)?;
m.add_function(wrap_pyfunction!(login_by_cookies, m)?)?;
m.add_function(wrap_pyfunction!(send_sms, m)?)?;
m.add_function(wrap_pyfunction!(login_by_qrcode, m)?)?;
Expand Down
Loading

0 comments on commit 6d6a4db

Please sign in to comment.