1pub mod error;
7pub mod future;
8mod layer;
9
10pub use self::layer::TimeoutLayer;
11
12use self::future::ResponseFuture;
13use std::task::{Context, Poll};
14use std::time::Duration;
15use tower_service::Service;
16
/// Middleware that pairs an inner service with a per-request time limit.
///
/// The value itself only stores the wrapped service and the [`Duration`]; the
/// [`Service`] implementation for `Timeout` starts a `tokio::time::sleep` of
/// that duration for every call.
#[derive(Clone, Debug)]
pub struct Timeout<T> {
    /// The wrapped service.
    inner: T,
    /// Duration handed to the sleep timer created for each call.
    timeout: Duration,
}

impl<T> Timeout<T> {
    /// Wraps `inner` together with the given `timeout`.
    ///
    /// This is a `const fn`, so it can be used in constant contexts.
    pub const fn new(inner: T, timeout: Duration) -> Self {
        Self { inner, timeout }
    }

    /// Returns a shared reference to the wrapped service.
    pub fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Returns an exclusive reference to the wrapped service.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Consumes `self`, returning the wrapped service and discarding the timeout.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }
}
47
48impl<S, Request> Service<Request> for Timeout<S>
49where
50 S: Service<Request>,
51 S::Error: Into<crate::BoxError>,
52{
53 type Response = S::Response;
54 type Error = crate::BoxError;
55 type Future = ResponseFuture<S::Future>;
56
57 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58 match self.inner.poll_ready(cx) {
59 Poll::Pending => Poll::Pending,
60 Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
61 }
62 }
63
64 fn call(&mut self, request: Request) -> Self::Future {
65 let response = self.inner.call(request);
66 let sleep = tokio::time::sleep(self.timeout);
67
68 ResponseFuture::new(response, sleep)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        convert::Infallible,
        future::Future,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    };
    use tokio::time::sleep;
    use tower_service::Service;

    /// Boxed future type shared by the stub services below.
    type StubFuture<T> = Pin<Box<dyn Future<Output = Result<T, Infallible>> + Send>>;

    /// Stub service whose response resolves only after sleeping for the
    /// wrapped duration.
    struct SlowService(Duration);

    impl Service<()> for SlowService {
        type Response = ();
        type Error = Infallible;
        type Future = StubFuture<()>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            // Copy the duration out so the future does not borrow `self`.
            let wait = self.0;
            Box::pin(async move {
                sleep(wait).await;
                Ok(())
            })
        }
    }

    /// Stub service that answers `"ok"` without awaiting anything.
    struct FastService;

    impl Service<()> for FastService {
        type Response = &'static str;
        type Error = Infallible;
        type Future = StubFuture<&'static str>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ()) -> Self::Future {
            Box::pin(async move { Ok("ok") })
        }
    }

    /// A 10 s inner service behind a 1 s timeout must fail with `error::Elapsed`.
    #[tokio::test(start_paused = true)]
    async fn elapsed_error_when_timeout_exceeded() {
        let inner = SlowService(Duration::from_secs(10));
        let mut svc = Timeout::new(inner, Duration::from_secs(1));

        let outcome = svc.call(()).await;
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(err.is::<error::Elapsed>());
    }

    /// An immediately-ready inner service yields its response unchanged.
    #[tokio::test(start_paused = true)]
    async fn response_passes_through_when_under_timeout() {
        let mut svc = Timeout::new(FastService, Duration::from_secs(1));

        let outcome = svc.call(()).await;
        assert_eq!(outcome.unwrap(), "ok");
    }
}