Skip to main content

tower/timeout/
mod.rs

1//! Middleware that applies a timeout to requests.
2//!
3//! If the response does not complete within the specified timeout, the response
4//! will be aborted.
5
6pub mod error;
7pub mod future;
8mod layer;
9
10pub use self::layer::TimeoutLayer;
11
12use self::future::ResponseFuture;
13use std::task::{Context, Poll};
14use std::time::Duration;
15use tower_service::Service;
16
/// Middleware that bounds how long a wrapped service may take to respond.
///
/// `inner` is the wrapped service and `timeout` is the time budget handed to
/// each request. Build one with [`Timeout::new`]. The wrapped service stays
/// reachable through [`get_ref`](Timeout::get_ref),
/// [`get_mut`](Timeout::get_mut) and [`into_inner`](Timeout::into_inner).
#[derive(Clone, Debug)]
pub struct Timeout<T> {
    inner: T,
    timeout: Duration,
}

// ===== impl Timeout =====

impl<T> Timeout<T> {
    /// Wraps `inner`, giving each request a budget of `timeout`.
    ///
    /// This is a `const fn`, so a `Timeout` can be built in constant context.
    pub const fn new(inner: T, timeout: Duration) -> Self {
        Self { timeout, inner }
    }

    /// Borrows the wrapped service.
    pub fn get_ref(&self) -> &T {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped service.
    pub fn get_mut(&mut self) -> &mut T {
        let Self { inner, .. } = self;
        inner
    }

    /// Unwraps the middleware and hands back the wrapped service.
    ///
    /// The configured duration is discarded.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }
}
47
48impl<S, Request> Service<Request> for Timeout<S>
49where
50    S: Service<Request>,
51    S::Error: Into<crate::BoxError>,
52{
53    type Response = S::Response;
54    type Error = crate::BoxError;
55    type Future = ResponseFuture<S::Future>;
56
57    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        match self.inner.poll_ready(cx) {
59            Poll::Pending => Poll::Pending,
60            Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
61        }
62    }
63
64    fn call(&mut self, request: Request) -> Self::Future {
65        let response = self.inner.call(request);
66        let sleep = tokio::time::sleep(self.timeout);
67
68        ResponseFuture::new(response, sleep)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        convert::Infallible,
        future::Future,
        io,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    };
    use tokio::time::sleep;
    use tower_service::Service;

    /// Service that responds with `()` only after sleeping for the wrapped
    /// duration.
    struct SlowService(Duration);

    impl Service<()> for SlowService {
        type Response = ();
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<(), Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            let delay = self.0;
            Box::pin(async move {
                sleep(delay).await;
                Ok(())
            })
        }
    }

    /// Service that responds with `"ok"` immediately.
    struct FastService;

    impl Service<()> for FastService {
        type Response = &'static str;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<&'static str, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            Box::pin(async move { Ok("ok") })
        }
    }

    /// Service whose response future always fails immediately with an I/O
    /// error.
    struct FailingService;

    impl Service<()> for FailingService {
        type Response = ();
        type Error = io::Error;
        type Future = Pin<Box<dyn Future<Output = Result<(), io::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            Box::pin(async move { Err(io::Error::new(io::ErrorKind::Other, "boom")) })
        }
    }

    // Paused time makes the runtime jump straight to the next timer when idle,
    // so these tests do not actually wait.
    #[tokio::test(start_paused = true)]
    async fn elapsed_error_when_timeout_exceeded() {
        let mut svc = Timeout::new(SlowService(Duration::from_secs(10)), Duration::from_secs(1));

        // `unwrap_err` already fails the test if the call succeeded.
        let err = svc.call(()).await.unwrap_err();
        assert!(err.downcast_ref::<error::Elapsed>().is_some());
    }

    #[tokio::test(start_paused = true)]
    async fn response_passes_through_when_under_timeout() {
        let mut svc = Timeout::new(FastService, Duration::from_secs(1));

        let res = svc.call(()).await;
        assert_eq!(res.unwrap(), "ok");
    }

    #[tokio::test(start_paused = true)]
    async fn slow_response_just_under_timeout_succeeds() {
        // The inner timer (999ms) fires strictly before the 1s deadline.
        let inner = SlowService(Duration::from_millis(999));
        let mut svc = Timeout::new(inner, Duration::from_secs(1));

        svc.call(()).await.expect("response finished before the deadline");
    }

    #[tokio::test(start_paused = true)]
    async fn inner_error_is_converted_not_masked() {
        let mut svc = Timeout::new(FailingService, Duration::from_secs(1));

        let err = svc.call(()).await.unwrap_err();
        assert!(err.downcast_ref::<error::Elapsed>().is_none());

        let io_err = err
            .downcast_ref::<io::Error>()
            .expect("inner io::Error should be boxed unchanged");
        assert_eq!(io_err.to_string(), "boom");
    }

    #[test]
    fn accessors_expose_inner_service() {
        let mut svc = Timeout::new(SlowService(Duration::from_secs(10)), Duration::from_secs(1));
        assert_eq!(svc.get_ref().0, Duration::from_secs(10));

        svc.get_mut().0 = Duration::from_secs(2);
        assert_eq!(svc.into_inner().0, Duration::from_secs(2));
    }
}