Skip to content

Commit

Permalink
Feat: Add retry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
PeiPei233 committed Jan 2, 2024
1 parent 55daf03 commit 9bff86e
Showing 1 changed file with 92 additions and 58 deletions.
150 changes: 92 additions & 58 deletions src-tauri/src/zju_assist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ use num_bigint::BigUint;
use percent_encoding::percent_decode_str;
use regex::Regex;
use reqwest::cookie::{CookieStore, Jar};
use reqwest::header::{
HeaderMap, AUTHORIZATION, CONNECTION, HOST, ORIGIN, REFERER, UPGRADE_INSECURE_REQUESTS,
USER_AGENT,
};
use reqwest::{Client, RequestBuilder, Response};
use reqwest::header::{HeaderMap, AUTHORIZATION, USER_AGENT};
use reqwest::{Client, Method, RequestBuilder, Response};
use reqwest::{Error, IntoUrl};
use serde::Serialize;
use serde_json::Value;
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -20,13 +18,17 @@ use crate::model::Subject;

#[derive(Clone)]
pub struct ZjuAssist {
client: Client,
jar: Arc<Jar>,
have_login: bool,
}

impl ZjuAssist {
pub fn new() -> Self {
pub struct ZjuRequestBuilder {
request_builder: RequestBuilder,
request_builder_no_proxy: RequestBuilder,
}

impl ZjuRequestBuilder {
fn new<U: IntoUrl + Clone>(client: ZjuAssist, method: Method, url: U) -> Self {
let mut headers = HeaderMap::new();
headers.insert(
USER_AGENT,
Expand All @@ -35,16 +37,76 @@ impl ZjuAssist {
.unwrap(),
);

let jar = Arc::new(Jar::default());
let client = Client::builder()
.cookie_provider(Arc::clone(&jar))
let client_default = Client::builder()
.cookie_provider(Arc::clone(&client.jar))
.default_headers(headers.clone())
.build()
.unwrap();

let client_no_proxy = Client::builder()
.cookie_provider(Arc::clone(&client.jar))
.default_headers(headers)
.no_proxy()
.build()
.unwrap();

Self {
client,
jar,
request_builder: client_default.request(method.clone(), url.clone()),
request_builder_no_proxy: client_no_proxy.request(method, url),
}
}

pub fn headers(&mut self, headers: HeaderMap) -> &mut Self {
self.request_builder = self
.request_builder
.try_clone()
.unwrap()
.headers(headers.clone());
self.request_builder_no_proxy = self
.request_builder_no_proxy
.try_clone()
.unwrap()
.headers(headers.clone());
self
}

pub fn form<T: Serialize + ?Sized>(&mut self, form: &T) -> &mut Self {
self.request_builder = self.request_builder.try_clone().unwrap().form(form);
self.request_builder_no_proxy = self
.request_builder_no_proxy
.try_clone()
.unwrap()
.form(form);
self
}

pub async fn send(&self) -> Result<Response, Error> {
// total 6 retries, 3 with proxy, 3 without proxy
let mut res = self.request_builder.try_clone().unwrap().send().await;
let mut retries = 5;

while res.is_err() && retries > 0 {
retries -= 1;
if retries % 2 == 0 {
res = self
.request_builder_no_proxy
.try_clone()
.unwrap()
.send()
.await;
} else {
res = self.request_builder.try_clone().unwrap().send().await;
}
}

res
}
}

impl ZjuAssist {
pub fn new() -> Self {
Self {
jar: Arc::new(Jar::default()),
have_login: false,
}
}
Expand All @@ -64,14 +126,20 @@ impl ZjuAssist {
.collect()
}

pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn request<U: IntoUrl + Clone>(&self, method: Method, url: U) -> ZjuRequestBuilder {
ZjuRequestBuilder::new(self.clone(), method, url)
}

pub fn get<U: IntoUrl + Clone>(&self, url: U) -> ZjuRequestBuilder {
info!("GET {}", url.as_str());
self.client.get(url)
// self.client.get(url)
self.request(Method::GET, url)
}

pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn post<U: IntoUrl + Clone>(&self, url: U) -> ZjuRequestBuilder {
info!("POST {}", url.as_str());
self.client.post(url)
// self.client.post(url)
self.request(Method::POST, url)
}

pub async fn login(
Expand All @@ -83,20 +151,8 @@ impl ZjuAssist {
return Ok(());
}

let mut headers = HeaderMap::new();
headers.insert(
USER_AGENT,
"Mozilla/5.0 (X11; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0"
.parse()
.unwrap(),
);
headers.insert(CONNECTION, "keep-alive".parse().unwrap());
headers.insert(HOST, "zjuam.zju.edu.cn".parse().unwrap());
headers.insert(UPGRADE_INSECURE_REQUESTS, "1".parse().unwrap());

let res = self
.get("https://zjuam.zju.edu.cn/cas/login")
.headers(headers.clone())
.send()
.await?;

Expand All @@ -105,7 +161,6 @@ impl ZjuAssist {
self.logout();
let res = self
.get("https://zjuam.zju.edu.cn/cas/login")
.headers(headers.clone())
.send()
.await?;
text = res.text().await?;
Expand All @@ -118,14 +173,8 @@ impl ZjuAssist {
.captures(&text)
.and_then(|cap| cap.get(1).map(|m| m.as_str()))
.ok_or("Execution value not found")?;

headers.insert(
REFERER,
"https://zjuam.zju.edu.cn/cas/login".parse().unwrap(),
);
let res = self
.get("https://zjuam.zju.edu.cn/cas/v2/getPubKey")
.headers(headers.clone())
.send()
.await?;

Expand All @@ -143,15 +192,14 @@ impl ZjuAssist {
("authcode", ""),
];

headers.insert(ORIGIN, "https://zjuam.zju.edu.cn".parse().unwrap());
let res = self
.post("https://zjuam.zju.edu.cn/cas/login")
.form(&data)
.send()
.await?;

if res.text().await?.contains("统一身份认证平台") {
Err("Login failed".into())
Err("Login failed: Wrong username or password".into())
} else {
self.get("https://courses.zju.edu.cn/user/courses")
.send()
Expand All @@ -166,20 +214,7 @@ impl ZjuAssist {
}

pub fn logout(&mut self) {
let mut headers = HeaderMap::new();
headers.insert(
USER_AGENT,
"Mozilla/5.0 (X11; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0"
.parse()
.unwrap(),
);
let jar = Arc::new(Jar::default());
self.client = Client::builder()
.cookie_provider(Arc::clone(&jar))
.default_headers(headers)
.build()
.unwrap();
self.jar = jar;
self.jar = Arc::new(Jar::default());
self.have_login = false;
}

Expand Down Expand Up @@ -400,10 +435,10 @@ impl ZjuAssist {
let token = re
.captures(&cookie_str)
.and_then(|cap| cap.get(1).map(|m| m.as_str()))
.ok_or("Token not found")?;
.ok_or("Token not found, try log in again")?;
Ok(token.to_string())
} else {
Err("Token not found".into())
Err("Token not found, try log in again".into())
}
}

Expand Down Expand Up @@ -622,9 +657,8 @@ impl ZjuAssist {
}
}

pub async fn get<T: IntoUrl>(url: T) -> Result<Response, Error> {
info!("GET {}", url.as_str());
reqwest::get(url).await
pub async fn get<T: IntoUrl + Clone>(url: T) -> Result<Response, Error> {
ZjuAssist::new().get(url).send().await
}

pub async fn get_ppt_urls(
Expand Down Expand Up @@ -672,7 +706,7 @@ pub async fn download_ppt_image(url: &str, path: &str) -> Result<(), Box<dyn std
while retries < MAX_RETRIES {
let res = get(url).await?;
let content = res.bytes().await?;
if content.is_empty() {
if content.is_empty() || image::guess_format(&content).is_err() {
retries += 1;
continue;
}
Expand Down

0 comments on commit 9bff86e

Please sign in to comment.