172 lines
4.4 KiB
Rust
172 lines
4.4 KiB
Rust
use std::{
|
|
collections::HashSet,
|
|
future::Future,
|
|
net::{IpAddr, SocketAddr},
|
|
pin::Pin,
|
|
sync::Arc,
|
|
};
|
|
|
|
use anyhow::{Context, Result};
|
|
use askama::Template;
|
|
use axum::{
|
|
extract::{ConnectInfo, State},
|
|
http::HeaderMap,
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use tokio::{fs::File, io::AsyncWriteExt, net::TcpListener, sync::RwLock};
|
|
|
|
async fn routes() -> Router {
|
|
Router::new()
|
|
.route("/", get(home))
|
|
.with_state(AppState::load().await)
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
view_count: Arc<RwLock<u32>>,
|
|
seen_ips: Arc<RwLock<HashSet<IpAddr>>>,
|
|
}
|
|
|
|
impl AppState {
|
|
async fn load() -> Self {
|
|
let view_count = match read_count().await {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
tracing::error!(message = "Failed to read count", error = e.to_string(),);
|
|
0
|
|
}
|
|
};
|
|
|
|
let view_count = Arc::new(RwLock::new(view_count));
|
|
let seen_ips = Arc::default();
|
|
|
|
Self {
|
|
view_count,
|
|
seen_ips,
|
|
}
|
|
}
|
|
|
|
async fn increment(&self, ip: IpAddr) -> u32 {
|
|
if self.seen_ips.read().await.contains(&ip) {
|
|
let view_count = self.view_count.read().await;
|
|
return *view_count;
|
|
}
|
|
self.seen_ips.write().await.insert(ip);
|
|
|
|
let mut view_count = self.view_count.write().await;
|
|
*view_count += 1;
|
|
|
|
if let Err(e) = write_count(*view_count).await {
|
|
tracing::error!(message = "Failed to write count", error = e.to_string(),);
|
|
}
|
|
|
|
*view_count
|
|
}
|
|
}
|
|
|
|
#[derive(Template)]
|
|
#[template(path = "index.html")]
|
|
pub struct HomeTemplate {
|
|
view_count: u32,
|
|
}
|
|
|
|
fn state_path() -> anyhow::Result<std::path::PathBuf> {
|
|
let project_dirs = directories::ProjectDirs::from("com", "lelgenio", "made-you-look")
|
|
.context("building project dirs path")?;
|
|
|
|
let state_dir = project_dirs.state_dir().context("getting state dir")?;
|
|
|
|
Ok(state_dir.join("state"))
|
|
}
|
|
|
|
async fn read_count() -> anyhow::Result<u32> {
|
|
let file_path = state_path()?;
|
|
let s = tokio::fs::read_to_string(file_path)
|
|
.await
|
|
.context("reading count")?;
|
|
|
|
s.parse().context("parsing count")
|
|
}
|
|
|
|
async fn write_count(new_count: u32) -> anyhow::Result<()> {
|
|
let file_path = state_path()?;
|
|
|
|
// let s = std::fs::read_to_string(file_path).context("reading count")?;
|
|
tokio::fs::create_dir_all(file_path.parent().context("Getting state file dirname")?).await?;
|
|
let mut file = File::create(file_path).await.context("Creating")?;
|
|
|
|
file.write(new_count.to_string().as_bytes())
|
|
.await
|
|
.context("Writting new count")?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn ip_from_headers(headers: &HeaderMap) -> Option<IpAddr> {
|
|
let value = headers.get("x-forwarded-for")?;
|
|
let ip_str = String::from_utf8_lossy(value.as_bytes());
|
|
|
|
ip_str.parse::<IpAddr>().ok()
|
|
}
|
|
|
|
#[axum::debug_handler]
|
|
pub async fn home(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
) -> HomeTemplate {
|
|
let direct_ip = addr.ip();
|
|
|
|
let ip: IpAddr = ip_from_headers(&headers).unwrap_or(direct_ip);
|
|
|
|
let view_count = state.increment(ip).await;
|
|
|
|
HomeTemplate { view_count }
|
|
}
|
|
|
|
/// Server configuration passed to [`run`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Config {
    /// TCP port to listen on (on all IPv4 interfaces). `0` lets the OS pick a free port;
    /// the port actually bound is reported back by `run`.
    pub port: u16,
}
|
|
|
|
pub struct RunningServer {
|
|
pub port: u16,
|
|
pub server: Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>,
|
|
}
|
|
|
|
pub async fn run(config: Config) -> Result<RunningServer> {
|
|
setup_tracing();
|
|
|
|
let router = routes()
|
|
.await
|
|
.layer(tower_http::trace::TraceLayer::new_for_http())
|
|
.into_make_service_with_connect_info::<SocketAddr>();
|
|
|
|
let tcp_listener = TcpListener::bind(format!("0.0.0.0:{}", config.port)).await?;
|
|
|
|
let port = tcp_listener.local_addr()?.port();
|
|
|
|
tracing::info!("Listening on http://localhost:{port}");
|
|
let server = Box::pin(async move {
|
|
axum::serve(tcp_listener, router).await?;
|
|
Ok(())
|
|
});
|
|
|
|
Ok(RunningServer { port, server })
|
|
}
|
|
|
|
pub fn setup_tracing() {
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
|
|
|
let log_filter =
|
|
std::env::var("MADE_YOU_LOOK_LOG").unwrap_or_else(|_| "made_you_look=debug,warn".into());
|
|
|
|
eprintln!("RUST_LOG: {log_filter}");
|
|
|
|
tracing_subscriber::registry()
|
|
.with(tracing_subscriber::EnvFilter::new(log_filter))
|
|
.with(tracing_subscriber::fmt::layer())
|
|
.try_init()
|
|
.ok();
|
|
}
|