scuffle_http/backend/hyper/
mod.rs1use std::fmt::Debug;
3use std::net::SocketAddr;
4
5use scuffle_context::ContextFutExt;
6#[cfg(feature = "tracing")]
7use tracing::Instrument;
8
9use crate::error::HttpError;
10use crate::service::{HttpService, HttpServiceFactory};
11
12mod handler;
13mod stream;
14mod utils;
15
16#[derive(Debug, Clone, bon::Builder)]
22pub struct HyperBackend<F> {
23 #[builder(default = scuffle_context::Context::global())]
25 ctx: scuffle_context::Context,
26 #[builder(default = 1)]
28 worker_tasks: usize,
29 service_factory: F,
31 bind: SocketAddr,
36 #[cfg(feature = "tls-rustls")]
41 rustls_config: Option<rustls::ServerConfig>,
42 #[cfg(feature = "http1")]
44 #[builder(default = true)]
45 http1_enabled: bool,
46 #[cfg(feature = "http2")]
48 #[builder(default = true)]
49 http2_enabled: bool,
50}
51
52impl<F> HyperBackend<F>
53where
54 F: HttpServiceFactory + Clone + Send + 'static,
55 F::Error: std::error::Error + Send,
56 F::Service: Clone + Send + 'static,
57 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
58 <F::Service as HttpService>::ResBody: Send,
59 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
60 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
61{
62 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
66 #[allow(unused_mut)] pub async fn run(mut self) -> Result<(), HttpError<F>> {
68 #[cfg(feature = "tracing")]
69 tracing::debug!("starting server");
70
71 #[cfg(feature = "tls-rustls")]
74 if let Some(rustls_config) = self.rustls_config.as_mut() {
75 rustls_config.max_early_data_size = 0;
76 }
77
78 let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
80
81 #[cfg(feature = "tls-rustls")]
82 let tls_acceptor = self
83 .rustls_config
84 .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
85
86 let (worker_ctx, worker_handler) = self.ctx.new_child();
88
89 let workers = (0..self.worker_tasks)
90 .map(|_n| {
91 let service_factory = self.service_factory.clone();
92 let ctx = worker_ctx.clone();
93 let std_listener = listener.try_clone()?;
94 let listener = tokio::net::TcpListener::from_std(std_listener)?;
95 #[cfg(feature = "tls-rustls")]
96 let tls_acceptor = tls_acceptor.clone();
97
98 let worker_fut = async move {
99 loop {
100 #[cfg(feature = "tracing")]
101 tracing::trace!("waiting for connections");
102
103 let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
104 Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
105 Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
106 #[cfg(feature = "tracing")]
107 tracing::error!(err = %e, "failed to accept tcp connection");
108 return Err(HttpError::<F>::from(e));
109 }
110 Some(Err(_)) => continue,
111 None => {
112 #[cfg(feature = "tracing")]
113 tracing::trace!("context done, stopping listener");
114 break;
115 }
116 };
117
118 #[cfg(feature = "tracing")]
119 tracing::trace!(addr = %addr, "accepted tcp connection");
120
121 let ctx = ctx.clone();
122 #[cfg(feature = "tls-rustls")]
123 let tls_acceptor = tls_acceptor.clone();
124 let mut service_factory = service_factory.clone();
125
126 let connection_fut = async move {
127 #[cfg(feature = "tls-rustls")]
129 if let Some(tls_acceptor) = tls_acceptor {
130 #[cfg(feature = "tracing")]
131 tracing::trace!("accepting tls connection");
132
133 stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
134 Some(Ok(stream)) => stream,
135 Some(Err(_err)) => {
136 #[cfg(feature = "tracing")]
137 tracing::warn!(err = %_err, "failed to accept tls connection");
138 return;
139 }
140 None => {
141 #[cfg(feature = "tracing")]
142 tracing::trace!("context done, stopping tls acceptor");
143 return;
144 }
145 };
146
147 #[cfg(feature = "tracing")]
148 tracing::trace!("accepted tls connection");
149 }
150
151 let http_service = match service_factory.new_service(addr).await {
153 Ok(service) => service,
154 Err(_e) => {
155 #[cfg(feature = "tracing")]
156 tracing::warn!(err = %_e, "failed to create service");
157 return;
158 }
159 };
160
161 #[cfg(feature = "tracing")]
162 tracing::trace!("handling connection");
163
164 #[cfg(feature = "http1")]
165 let http1 = self.http1_enabled;
166 #[cfg(not(feature = "http1"))]
167 let http1 = false;
168
169 #[cfg(feature = "http2")]
170 let http2 = self.http2_enabled;
171 #[cfg(not(feature = "http2"))]
172 let http2 = false;
173
174 let _res = handler::handle_connection::<F, _, _>(ctx, http_service, stream, http1, http2).await;
175
176 #[cfg(feature = "tracing")]
177 if let Err(e) = _res {
178 tracing::warn!(err = %e, "error handling connection");
179 }
180
181 #[cfg(feature = "tracing")]
182 tracing::trace!("connection closed");
183 };
184
185 #[cfg(feature = "tracing")]
186 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
187
188 tokio::spawn(connection_fut);
189 }
190
191 #[cfg(feature = "tracing")]
192 tracing::trace!("listener closed");
193
194 Ok(())
195 };
196
197 #[cfg(feature = "tracing")]
198 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
199
200 Ok(tokio::spawn(worker_fut))
201 })
202 .collect::<std::io::Result<Vec<_>>>()?;
203
204 match futures::future::try_join_all(workers).await {
205 Ok(res) => {
206 for r in res {
207 if let Err(e) = r {
208 drop(worker_ctx);
209 worker_handler.shutdown().await;
210 return Err(e);
211 }
212 }
213 }
214 Err(_e) => {
215 #[cfg(feature = "tracing")]
216 tracing::error!(err = %_e, "error running workers");
217 }
218 }
219
220 drop(worker_ctx);
221 worker_handler.shutdown().await;
222
223 #[cfg(feature = "tracing")]
224 tracing::debug!("all workers finished");
225
226 Ok(())
227 }
228}