Skip to content

Commit e57bf2a

Browse files
fix: allow for recursive block-on calls
1 parent 6d2a43e commit e57bf2a

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

Cargo.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,7 @@ required-features = ["unstable"]
9898

9999
[[example]]
100100
name = "surf-web"
101-
required-features = ["surf"]
101+
required-features = ["surf"]
102+
103+
[patch.crates-io]
104+
smol = { git = "https://github.com/dignifiedquire/smol-1", branch = "feat/recursive-block-on" }

src/task/builder.rs

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::cell::Cell;
12
use std::future::Future;
23
use std::pin::Pin;
34
use std::sync::Arc;
@@ -8,6 +9,11 @@ use pin_project_lite::pin_project;
89
use crate::io;
910
use crate::task::{JoinHandle, Task, TaskLocalsWrapper};
1011

12+
#[cfg(not(target_os = "unknown"))]
13+
thread_local! {
14+
static IS_CURRENT_BLOCKING: Cell<usize> = Cell::new(0);
15+
}
16+
1117
/// Task builder that configures the settings of a new task.
1218
#[derive(Debug, Default)]
1319
pub struct Builder {
@@ -151,7 +157,36 @@ impl Builder {
151157
});
152158

153159
// Run the future as a task.
154-
unsafe { TaskLocalsWrapper::set_current(&wrapped.tag, || smol::run(wrapped)) }
160+
IS_CURRENT_BLOCKING.with(|is_current_blocking| {
161+
let count = is_current_blocking.get();
162+
let res = if count == 0 {
163+
// increase the count
164+
is_current_blocking.replace(1);
165+
166+
// The first call should use run.
167+
unsafe {
168+
TaskLocalsWrapper::set_current(&wrapped.tag, || {
169+
let res = smol::run(wrapped);
170+
is_current_blocking.replace(is_current_blocking.get() - 1);
171+
res
172+
})
173+
}
174+
} else {
175+
// increase the count
176+
is_current_blocking.replace(is_current_blocking.get() + 1);
177+
178+
// Subsequent calls should be using `block_on`.
179+
unsafe {
180+
TaskLocalsWrapper::set_current(&wrapped.tag, || {
181+
let res = smol::block_on(wrapped);
182+
is_current_blocking.replace(is_current_blocking.get() - 1);
183+
res
184+
})
185+
}
186+
};
187+
188+
res
189+
})
155190
}
156191
}
157192

tests/block_on.rs

+59
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,62 @@ fn panic() {
1616
panic!("boom");
1717
});
1818
}
19+
20+
#[cfg(feature = "unstable")]
21+
#[test]
22+
fn nested_block_on_local() {
23+
let x = task::block_on(async {
24+
let a =
25+
task::block_on(async { task::block_on(async { async_std::future::ready(3).await }) });
26+
let b = task::spawn_local(async {
27+
task::block_on(async { async_std::future::ready(2).await })
28+
})
29+
.await;
30+
let c =
31+
task::block_on(async { task::block_on(async { async_std::future::ready(1).await }) });
32+
a + b + c
33+
});
34+
35+
assert_eq!(x, 3 + 2 + 1);
36+
37+
let y = task::block_on(async {
38+
let a =
39+
task::block_on(async { task::block_on(async { async_std::future::ready(3).await }) });
40+
let b = task::spawn_local(async {
41+
task::block_on(async { async_std::future::ready(2).await })
42+
})
43+
.await;
44+
let c =
45+
task::block_on(async { task::block_on(async { async_std::future::ready(1).await }) });
46+
a + b + c
47+
});
48+
49+
assert_eq!(y, 3 + 2 + 1);
50+
}
51+
52+
#[test]
53+
fn nested_block_on() {
54+
let x = task::block_on(async {
55+
let a =
56+
task::block_on(async { task::block_on(async { async_std::future::ready(3).await }) });
57+
let b =
58+
task::block_on(async { task::block_on(async { async_std::future::ready(2).await }) });
59+
let c =
60+
task::block_on(async { task::block_on(async { async_std::future::ready(1).await }) });
61+
a + b + c
62+
});
63+
64+
assert_eq!(x, 3 + 2 + 1);
65+
66+
let y = task::block_on(async {
67+
let a =
68+
task::block_on(async { task::block_on(async { async_std::future::ready(3).await }) });
69+
let b =
70+
task::block_on(async { task::block_on(async { async_std::future::ready(2).await }) });
71+
let c =
72+
task::block_on(async { task::block_on(async { async_std::future::ready(1).await }) });
73+
a + b + c
74+
});
75+
76+
assert_eq!(y, 3 + 2 + 1);
77+
}

0 commit comments

Comments
 (0)