scuffle_signal/
bootstrap.rs1use std::sync::Arc;
2
3use scuffle_bootstrap::global::Global;
4use scuffle_bootstrap::service::Service;
5use scuffle_context::ContextFutExt;
6
/// Bootstrap service that waits for process signals and then runs a
/// graceful-then-forced shutdown sequence configured through `SignalConfig`.
///
/// It is a zero-sized marker type; all configuration lives on the global state.
#[derive(Clone, Copy, Debug, Default)]
pub struct SignalSvc;
10
11pub trait SignalConfig: Global {
13 fn signals(&self) -> Vec<crate::SignalKind> {
17 vec![crate::SignalKind::Terminate, crate::SignalKind::Interrupt]
18 }
19
20 fn timeout(&self) -> Option<std::time::Duration> {
22 Some(std::time::Duration::from_secs(30))
23 }
24
25 fn on_shutdown(self: &Arc<Self>) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
27 std::future::ready(Ok(()))
28 }
29
30 fn on_force_shutdown(
32 &self,
33 signal: Option<crate::SignalKind>,
34 ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
35 let err = if let Some(signal) = signal {
36 anyhow::anyhow!("received signal, shutting down immediately: {:?}", signal)
37 } else {
38 anyhow::anyhow!("timeout reached, shutting down immediately")
39 };
40
41 std::future::ready(Err(err))
42 }
43
44 fn block_global_shutdown(&self) -> impl std::future::Future<Output = ()> + Send {
48 scuffle_context::Handler::global().shutdown()
49 }
50}
51
52impl<Global: SignalConfig> Service<Global> for SignalSvc {
53 fn enabled(&self, global: &Arc<Global>) -> impl std::future::Future<Output = anyhow::Result<bool>> + Send {
54 std::future::ready(Ok(!global.signals().is_empty()))
55 }
56
57 async fn run(self, global: Arc<Global>, ctx: scuffle_context::Context) -> anyhow::Result<()> {
58 let timeout = global.timeout();
59
60 let signals = global.signals();
61 anyhow::ensure!(!signals.is_empty(), "no signals to listen for");
62
63 let mut handler = crate::SignalHandler::with_signals(signals);
64
65 handler.recv().with_context(&ctx).await;
67 global.on_shutdown().await?;
68 drop(ctx);
69
70 tokio::select! {
71 signal = handler.recv() => {
72 global.on_force_shutdown(Some(signal)).await?;
73 },
74 _ = global.block_global_shutdown() => {}
75 Some(()) = async {
76 if let Some(timeout) = timeout {
77 tokio::time::sleep(timeout).await;
78 Some(())
79 } else {
80 None
81 }
82 } => {
83 global.on_force_shutdown(None).await?;
84 },
85 };
86
87 Ok(())
88 }
89}
90
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod test {
    use std::sync::Arc;

    use scuffle_bootstrap::{GlobalWithoutConfig, Service};
    use scuffle_future_ext::FutureExt;

    use super::SignalConfig;
    use crate::tests::raise_signal;
    use crate::{SignalKind, SignalSvc};

    /// Asserts that `handler.shutdown()` completes within one second.
    ///
    /// Shared by every test so a hang reports the same, descriptive failure.
    async fn assert_shutdown_completes(handler: scuffle_context::Handler) {
        assert!(
            handler
                .shutdown()
                .with_timeout(tokio::time::Duration::from_millis(1000))
                .await
                .is_ok(),
            "context handler did not shut down within 1s"
        );
    }

    /// Spawns the service, raises `Interrupt` twice, and expects the second
    /// signal to force shutdown with the default `on_force_shutdown` error.
    async fn force_shutdown_two_signals<Global: GlobalWithoutConfig + SignalConfig>() {
        let (ctx, handler) = scuffle_context::Context::new();

        let _global_ctx = scuffle_context::Context::global();

        let svc = SignalSvc;
        let global = <Global as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        raise_signal(SignalKind::Interrupt).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        raise_signal(SignalKind::Interrupt).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        match result.with_timeout(tokio::time::Duration::from_millis(1000)).await {
            Ok(Ok(Err(e))) => {
                assert_eq!(e.to_string(), "received signal, shutting down immediately: Interrupt");
            }
            r => panic!("unexpected result: {r:?}"),
        }

        assert_shutdown_completes(handler).await;
    }

    /// Default timeout; global shutdown never completes.
    struct TestGlobal;

    impl GlobalWithoutConfig for TestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for TestGlobal {
        async fn block_global_shutdown(&self) {
            std::future::pending().await
        }
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn default_bootstrap_service() {
        force_shutdown_two_signals::<TestGlobal>().await;
    }

    /// No timeout; global shutdown completes when the contained `Notify` fires.
    struct NoTimeoutTestGlobal(tokio::sync::Notify);

    impl GlobalWithoutConfig for NoTimeoutTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self(tokio::sync::Notify::new()))))
        }
    }

    impl SignalConfig for NoTimeoutTestGlobal {
        fn timeout(&self) -> Option<std::time::Duration> {
            None
        }

        async fn block_global_shutdown(&self) {
            self.0.notified().await;
        }
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn bootstrap_service_no_timeout() {
        let (ctx, handler) = scuffle_context::Context::new();
        let svc = SignalSvc;
        let global = <NoTimeoutTestGlobal as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let mut result = tokio::spawn(svc.run(global.clone(), ctx));

        // Give the spawned service a moment to start listening before raising the signal.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        raise_signal(SignalKind::Interrupt).await;
        // Without a timeout, the service must keep waiting for global shutdown.
        assert!(
            (&mut result)
                .with_timeout(tokio::time::Duration::from_millis(100))
                .await
                .is_err()
        );

        global.0.notify_one();

        assert!(result.with_timeout(tokio::time::Duration::from_millis(100)).await.is_ok());

        assert_shutdown_completes(handler).await;
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn bootstrap_service_force_shutdown() {
        force_shutdown_two_signals::<NoTimeoutTestGlobal>().await;
    }

    /// Configures no signals, so the service is disabled.
    struct NoSignalsTestGlobal;

    impl GlobalWithoutConfig for NoSignalsTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for NoSignalsTestGlobal {
        fn signals(&self) -> Vec<crate::SignalKind> {
            vec![]
        }

        fn timeout(&self) -> Option<std::time::Duration> {
            None
        }

        async fn block_global_shutdown(&self) {
            std::future::pending().await
        }
    }

    #[tokio::test]
    async fn bootstrap_service_no_signals() {
        let (ctx, handler) = scuffle_context::Context::new();
        let svc = SignalSvc;
        let global = <NoSignalsTestGlobal as GlobalWithoutConfig>::init().await.unwrap();

        assert!(!svc.enabled(&global).await.unwrap());
        let result = svc.run(global, ctx).await.unwrap_err();

        assert_eq!(result.to_string(), "no signals to listen for");

        assert_shutdown_completes(handler).await;
    }

    /// 50 ms timeout; global shutdown never completes, so the timeout must fire.
    struct SmallTimeoutTestGlobal;

    impl GlobalWithoutConfig for SmallTimeoutTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for SmallTimeoutTestGlobal {
        fn timeout(&self) -> Option<std::time::Duration> {
            Some(std::time::Duration::from_millis(50))
        }

        async fn block_global_shutdown(&self) {
            std::future::pending().await
        }
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn bootstrap_service_timeout_force_shutdown() {
        let (ctx, handler) = scuffle_context::Context::new();

        let _global_ctx = scuffle_context::Context::global();

        let svc = SignalSvc;
        let global = <SmallTimeoutTestGlobal as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        raise_signal(crate::SignalKind::Interrupt).await;

        match result.with_timeout(tokio::time::Duration::from_millis(1000)).await {
            Ok(Ok(Err(e))) => {
                assert_eq!(e.to_string(), "timeout reached, shutting down immediately");
            }
            // Report what actually came back, consistent with `force_shutdown_two_signals`.
            r => panic!("unexpected result: {r:?}"),
        }

        assert_shutdown_completes(handler).await;
    }
}