use std::{ collections::HashSet, future::Future, net::{IpAddr, SocketAddr}, pin::Pin, sync::Arc, }; use anyhow::{Context, Result}; use askama::Template; use axum::{ extract::{ConnectInfo, State}, http::HeaderMap, routing::get, Router, }; use tokio::{fs::File, io::AsyncWriteExt, net::TcpListener, sync::RwLock}; async fn routes() -> Router { Router::new() .route("/", get(home)) .with_state(AppState::load().await) } #[derive(Clone)] pub struct AppState { view_count: Arc>, seen_ips: Arc>>, } impl AppState { async fn load() -> Self { let view_count = match read_count().await { Ok(v) => v, Err(e) => { tracing::error!(message = "Failed to read count", error = e.to_string(),); 0 } }; let view_count = Arc::new(RwLock::new(view_count)); let seen_ips = Arc::default(); Self { view_count, seen_ips, } } async fn increment(&self, ip: IpAddr) -> u32 { if self.seen_ips.read().await.contains(&ip) { let view_count = self.view_count.read().await; return *view_count; } self.seen_ips.write().await.insert(ip); let mut view_count = self.view_count.write().await; *view_count += 1; if let Err(e) = write_count(*view_count).await { tracing::error!(message = "Failed to write count", error = e.to_string(),); } *view_count } } #[derive(Template)] #[template(path = "index.html")] pub struct HomeTemplate { view_count: u32, } fn state_path() -> anyhow::Result { let project_dirs = directories::ProjectDirs::from("com", "lelgenio", "made-you-look") .context("building project dirs path")?; let state_dir = project_dirs.state_dir().context("getting state dir")?; Ok(state_dir.join("state")) } async fn read_count() -> anyhow::Result { let file_path = state_path()?; let s = tokio::fs::read_to_string(file_path) .await .context("reading count")?; s.parse().context("parsing count") } async fn write_count(new_count: u32) -> anyhow::Result<()> { let file_path = state_path()?; // let s = std::fs::read_to_string(file_path).context("reading count")?; tokio::fs::create_dir_all(file_path.parent().context("Getting state file dirname")?).await?; let mut file = File::create(file_path).await.context("Creating")?; file.write(new_count.to_string().as_bytes()) .await .context("Writting new count")?; Ok(()) } fn ip_from_headers(headers: &HeaderMap) -> Option { let value = headers.get("x-forwarded-for")?; let ip_str = String::from_utf8_lossy(value.as_bytes()); ip_str.parse::().ok() } #[axum::debug_handler] pub async fn home( State(state): State, headers: HeaderMap, ConnectInfo(addr): ConnectInfo, ) -> HomeTemplate { let direct_ip = addr.ip(); let ip: IpAddr = ip_from_headers(&headers).unwrap_or(direct_ip); let view_count = state.increment(ip).await; HomeTemplate { view_count } } pub struct Config { pub port: u16, } pub struct RunningServer { pub port: u16, pub server: Pin> + Send>>, } pub async fn run(config: Config) -> Result { setup_tracing(); let router = routes() .await .layer(tower_http::trace::TraceLayer::new_for_http()) .into_make_service_with_connect_info::(); let tcp_listener = TcpListener::bind(format!("0.0.0.0:{}", config.port)).await?; let port = tcp_listener.local_addr()?.port(); tracing::info!("Listening on http://localhost:{port}"); let server = Box::pin(async move { axum::serve(tcp_listener, router).await?; Ok(()) }); Ok(RunningServer { port, server }) } pub fn setup_tracing() { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; let log_filter = std::env::var("MADE_YOU_LOOK_LOG").unwrap_or_else(|_| "made_you_look=debug,warn".into()); eprintln!("RUST_LOG: {log_filter}"); tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new(log_filter)) .with(tracing_subscriber::fmt::layer()) .try_init() .ok(); }