You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

128 lines
3.0 KiB

//! Client banning: a shared ban list and a warp filter that rejects banned clients.
use std::sync::{
Arc,
Weak,
};
use std::{fmt, error};
use std::collections::BTreeSet;
use std::cmp::{PartialOrd, Ordering};
use tokio::sync::RwLock;
use std::net::IpAddr;
use std::iter::FromIterator;
use super::*;
use source::ClientInfo;
/// A single ban entry.
///
/// The derived `Ord` is what allows bans to be stored in an ordered set such as `BTreeSet`.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum BanKind
{
    /// Ban the client whose IP address is exactly this address.
    IP(IpAddr),
}
impl PartialEq<ClientInfo> for BanKind
{
fn eq(&self, other: &ClientInfo) -> bool
{
match self {
Self::IP(addr) => addr == &other.ip_addr,
}
}
}
/// Orders a ban relative to a client by comparing the banned address with the client's address.
impl PartialOrd<ClientInfo> for BanKind
{
    #[inline]
    fn partial_cmp(&self, other: &ClientInfo) -> Option<Ordering>
    {
        match self {
            // `IpAddr` has a total order consistent with its `PartialOrd`, so this is always `Some`.
            Self::IP(banned) => Some(Ord::cmp(banned, &other.ip_addr)),
        }
    }
}
impl ClientInfo
{
    /// Build an IP ban targeting this client's address.
    ///
    /// This only constructs the [`BanKind`]; add it to a [`Banlist`] to make it take effect.
    #[inline]
    pub fn ban_ip(&self) -> BanKind
    {
        let addr = self.ip_addr;
        BanKind::IP(addr)
    }
}
#[derive(Debug, Clone)]
pub struct Banlist(Arc<RwLock<BTreeSet<BanKind>>>);
//type OpaqueFuture<'a, T> = impl Future<Output = T> + 'a;
impl Banlist
{
/// Create a new, empty banlist.
///
/// To create one from a list of bans, this type implements `FromIterator`.
pub fn new() -> Self
{
Self(Arc::new(RwLock::new(BTreeSet::new())))
}
/// Add a ban to the list and wait for it to complete.
pub async fn add_ban_inline(&self, ban: BanKind) -> bool
{
self.0.write().await.insert(ban)
}
/// Add a ban to the list.
///
/// If the list is being used, this operation is deferred until it is able to complete. A future is returned to allow you to wait until the operation completes (it is a background task.)
pub fn add_ban(&self, ban: BanKind) -> impl Future<Output = bool>
{
let col = self.0.clone();
tokio::spawn(async move {
col.write().await.insert(ban)
}).map(|x| x.unwrap())
}
/// Create a warp filter that disallows hosts on the list.
///
/// # Lifetime
/// The filter holds a weak reference to this list. If the list is dropped, then the filter will panic.
pub fn filter(&self, client_info: ClientInfo) -> impl Future<Output = Result<ClientInfo, warp::reject::Rejection>> + 'static
{
let refer = Arc::downgrade(&self.0);
async move {
let refer = refer.upgrade().unwrap();
let bans = refer.read().await;
//XXX: FUCK there has to be a better way to check this right?
for fuck in bans.iter()
{
if fuck == &client_info {
return Err(warp::reject::custom(ClientBannedError));
}
}
Ok(client_info)
}
}
}
impl FromIterator<BanKind> for Banlist
{
fn from_iter<I: IntoIterator<Item=BanKind>>(iter: I) -> Self
{
Self(Arc::new(RwLock::new(iter.into_iter().collect())))
}
}
/// Rejection reason produced by `Banlist::filter` when the requesting client is on the banlist.
#[derive(Debug)]
pub struct ClientBannedError;

impl warp::reject::Reject for ClientBannedError {}

impl error::Error for ClientBannedError {}

impl fmt::Display for ClientBannedError
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        f.write_str("you are banned")
    }
}