Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ratelimiting #4

Merged
merged 6 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ SECRET_KEY="" // Your jwt secret key, you can generate one using https://www.grc
// or run this command in your terminal: `head /dev/urandom | shasum -a 256`
HOST="localhost" // Your host
PORT=8080 // Your port
RATE_LIMIT_BURST_SIZE=60 // Optional, default is 60
RATE_LIMIT_PER_SECOND=60 // Optional, default is 5
RATE_LIMIT_BURST_SIZE=30 // Optional, default is 30
RATE_LIMIT_PER_SECOND=60 // Optional, default is 60
API_CONTACT_NAME="" // The name of the API support contact
API_CONTACT_URL="" // The URL of the API support contact, e.g. https://example.com/support or the repo page (:
API_CONTACT_EMAIL="" // The email of the API support contact
Expand Down
98 changes: 9 additions & 89 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ log = "0.4.17"
pretty_env_logger = "0.4.0"
sea-orm = { version = "0.10.2", features = ["runtime-tokio-rustls", "runtime-tokio", "sqlx-sqlite"] }
sqlx-core = "0.6.2"
tokio = { version = "1.21.2", features = ["rt-multi-thread", "macros"] }
chrono = { version = "0.4.22", default-features = false, features = ["time"] }
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.88"
Expand All @@ -30,5 +29,4 @@ thiserror = "1.0.37"
hex = "0.4.3"
utoipa = { version = "2.3.0", features = ["actix_extras"] }
utoipa-swagger-ui = { version = "2.0.1", features = ["actix-web"] }
actix-governor = "0.4.0-beta.1"
governor = "0.4.2"
actix-extensible-rate-limit = {version = "0.2.1", default-features = false, features = ["dashmap"]}
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ Just like that, you have a RESTful API running on your machine.
| `SECRET_KEY` | The secret key for JWT | ` ` |
| `HOST` | The host to bind | `localhost` |
| `PORT` | The port to run the server | `8080` |
| `RATE_LIMIT_BURST_SIZE` | The burst size for rate limiter | `60` |
| `RATE_LIMIT_PER_SECOND` | The rate limit per second | `5` |
| `RATE_LIMIT_BURST_SIZE` | The burst size for rate limiter | `30` |
| `RATE_LIMIT_PER_SECOND` | The time to reset the burst | `60` |
| `API_CONTACT_NAME` | The name of the API contact | ` ` |
| `API_CONTACT_URL` | The url of the API contact | ` ` |
| `API_CONTACT_EMAIL` | The email of the API contact | ` ` |
Expand Down
6 changes: 3 additions & 3 deletions api-desc.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ The API has a rate limit of 60 burst requests, and 1 request per 5 seconds, if y
### Headers
<!-- The ratelimit headers -->

- `X-RateLimit-Limit`: Your burst limit.
- `X-RateLimit-Remaining`: The number of requests remaining in the current burst. Will return `Too Many Requests` if you exceed the limit.
- `X-RateLimit-Reset`: The number of seconds left to add a new request to the burst.
- `x-ratelimit-limit`: Your burst limit.
- `x-ratelimit-remaining`: The number of requests remaining in the current burst. Will return `Too Many Requests` if you exceed the limit.
- `x-ratelimit-reset`: The number of seconds left to reset the burst.
10 changes: 6 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::Path;

use actix_governor::Governor;
use actix_extensible_rate_limit::backend::memory::InMemoryBackend;
use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use migration::{Migrator, MigratorTrait};
Expand Down Expand Up @@ -28,7 +28,7 @@ pub async fn enishalize_poll() -> DatabaseConnection {
.expect("Failed to create database connection pool")
}

#[tokio::main]
#[actix_web::main]
async fn main() -> std::io::Result<()> {
dotenv::dotenv().ok();
pretty_env_logger::init();
Expand All @@ -48,12 +48,14 @@ async fn main() -> std::io::Result<()> {
);
log::info!("Swagger UI is available at http://{}/docs/swagger/", addr);

let ip_limit_config = crate::ratelimit::init_ip();
let ratelimit_backend = InMemoryBackend::builder().build();

HttpServer::new(move || {
let ratelimit_middleware_builder = ratelimit::init_ip(ratelimit_backend.clone());

App::new()
.app_data(web::Data::new(pool.clone()))
.wrap(Governor::new(&ip_limit_config))
.wrap(ratelimit_middleware_builder.build())
.wrap(Logger::default())
.service(web::scope("/api").configure(auth::init_routes))
.service(
Expand Down
85 changes: 36 additions & 49 deletions src/ratelimit.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,46 @@
use std::env;
use std::{env, future::Ready, time::Duration};

use actix_governor::{GovernorConfig, GovernorConfigBuilder};
use actix_web::{dev::ServiceRequest, HttpResponse, HttpResponseBuilder};
use governor::{
clock::{Clock, DefaultClock, QuantaInstant},
middleware::StateInformationMiddleware,
NotUntil,
use actix_extensible_rate_limit::{
backend::{memory::InMemoryBackend, SimpleInput, SimpleInputFunctionBuilder, SimpleOutput},
HeaderCompatibleOutput, RateLimiter, RateLimiterBuilder,
};
use actix_web::{dev::ServiceRequest, http::StatusCode, HttpResponse};

use crate::errors::TodoError as TodoErrorTrait;
use crate::{errors::Error as TodoError, schemas::errors::ErrorSchema};

#[derive(Debug, Clone)]
pub struct IpAddressExtractor;
/// The response error for rate limit exceeded
fn rate_limit_exceeded(rate_info: &SimpleOutput) -> HttpResponse {
let rest_time = rate_info.seconds_until_reset();
let body = ErrorSchema::from(TodoError::TooManyRequests(rest_time));

impl actix_governor::KeyExtractor for IpAddressExtractor {
type Key = String;
type KeyExtractionError = TodoError;

fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
req.connection_info()
.peer_addr()
.map(Into::into)
.server_err("Could not get IP address")
}

fn exceed_rate_limit_response(
&self,
negative: &NotUntil<QuantaInstant>,
mut res: HttpResponseBuilder,
) -> HttpResponse {
let wait_time = negative
.wait_time_from(DefaultClock::default().now())
.as_secs();
res.json(ErrorSchema::from(TodoError::TooManyRequests(wait_time)))
}
HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
.append_header(("x-ratelimit-limit", rate_info.limit()))
.append_header(("x-ratelimit-remaining", rate_info.remaining()))
.append_header(("x-ratelimit-reset", rest_time))
.json(body)
}

/// Initializes IP rate limiter
pub fn init_ip() -> GovernorConfig<IpAddressExtractor, StateInformationMiddleware> {
GovernorConfigBuilder::default()
.key_extractor(IpAddressExtractor)
.per_second(
env::var("RATE_LIMIT_PER_SECOND")
.unwrap_or_else(|_| "5".to_string())
.parse()
.expect("Invalid rate limit per second"),
)
.burst_size(
env::var("RATE_LIMIT_BURST_SIZE")
.unwrap_or_else(|_| "60".to_string())
.parse()
.expect("Invalid rate limit burst size"),
)
.use_headers()
.finish()
.unwrap()
pub fn init_ip(
backend: InMemoryBackend,
) -> RateLimiterBuilder<
InMemoryBackend,
SimpleOutput,
impl Fn(&ServiceRequest) -> Ready<Result<SimpleInput, actix_web::Error>>,
> {
let seconds: u64 = env::var("RATE_LIMIT_PER_SECOND")
.unwrap_or_else(|_| "60".to_string())
.parse()
.expect("Invalid rate limit per second");
let burst_size: u64 = env::var("RATE_LIMIT_BURST_SIZE")
.unwrap_or_else(|_| "30".to_string())
.parse()
.expect("Invalid rate limit burst size");
// Assign a limit of `burst_size` requests per `seconds` seconds per client ip address
let input = SimpleInputFunctionBuilder::new(Duration::from_secs(seconds), burst_size)
.real_ip_key()
.build();
RateLimiter::builder(backend, input)
.add_headers()
.request_denied_response(rate_limit_exceeded)
}