Skip to content

Commit 1057281

Browse files
authored
Merge pull request #156 from akadan47/fix-hls-compatibility-security
Content types for HLS playlist & segments.
2 parents ad709e0 + 81dcc8e commit 1057281

File tree

1 file changed

+152
-52
lines changed

1 file changed

+152
-52
lines changed

protocol/hls/src/server.rs

Lines changed: 152 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,81 +14,135 @@ use {
1414

1515
type GenericError = Box<dyn std::error::Error + Send + Sync>;
1616
type Result<T> = std::result::Result<T, GenericError>;
17+
1718
static NOTFOUND: &[u8] = b"Not Found";
1819
static UNAUTHORIZED: &[u8] = b"Unauthorized";
1920

20-
async fn handle_connection(State(auth): State<Option<Auth>>, req: Request<Body>) -> Response<Body> {
21-
let path = req.uri().path();
21+
#[derive(Debug)]
22+
enum HlsFileType {
23+
Playlist,
24+
Segment,
25+
}
2226

23-
let query_string: Option<String> = req.uri().query().map(|s| s.to_string());
24-
let mut file_path: String = String::from("");
25-
26-
if path.ends_with(".m3u8") {
27-
//http://127.0.0.1/app_name/stream_name/stream_name.m3u8
28-
let m3u8_index = path.find(".m3u8").unwrap();
29-
30-
if m3u8_index > 0 {
31-
let (left, _) = path.split_at(m3u8_index);
32-
let rv: Vec<_> = left.split('/').collect();
33-
34-
let app_name = String::from(rv[1]);
35-
let stream_name = String::from(rv[2]);
36-
37-
if let Some(auth_val) = auth {
38-
if auth_val
39-
.authenticate(
40-
&stream_name,
41-
&query_string.map(SecretCarrier::Query),
42-
true,
43-
)
44-
.is_err()
45-
{
46-
return Response::builder()
47-
.status(StatusCode::UNAUTHORIZED)
48-
.body(UNAUTHORIZED.into())
49-
.unwrap();
50-
}
51-
}
52-
53-
file_path = format!("./{app_name}/{stream_name}/{stream_name}.m3u8");
27+
impl HlsFileType {
28+
const CONTENT_TYPE_PLAYLIST: &'static str = "application/vnd.apple.mpegurl";
29+
const CONTENT_TYPE_SEGMENT: &'static str = "video/mp2t";
30+
31+
fn content_type(&self) -> &str {
32+
match self {
33+
Self::Playlist => Self::CONTENT_TYPE_PLAYLIST,
34+
Self::Segment => Self::CONTENT_TYPE_SEGMENT,
5435
}
55-
} else if path.ends_with(".ts") {
56-
//http://127.0.0.1/app_name/stream_name/ts_name.m3u8
57-
let ts_index = path.find(".ts").unwrap();
36+
}
37+
}
5838

59-
if ts_index > 0 {
60-
let (left, _) = path.split_at(ts_index);
39+
#[derive(Debug)]
40+
struct HlsPath {
41+
app_name: String,
42+
stream_name: String,
43+
file_name: String,
44+
file_type: HlsFileType,
45+
}
6146

62-
let rv: Vec<_> = left.split('/').collect();
47+
impl HlsPath {
48+
const M3U8_EXT: &'static str = "m3u8";
49+
const TS_EXT: &'static str = "ts";
6350

64-
let app_name = String::from(rv[1]);
65-
let stream_name = String::from(rv[2]);
66-
let ts_name = String::from(rv[3]);
51+
fn parse(path: &str) -> Option<Self> {
52+
if path.is_empty() || path.contains("..") {
53+
return None;
54+
}
55+
56+
let mut parts = path[1..].split('/');
57+
let app_name = parts.next()?;
58+
let stream_name = parts.next()?;
59+
let file_part = parts.next()?;
60+
if parts.next().is_some() {
61+
return None;
62+
}
6763

68-
file_path = format!("./{app_name}/{stream_name}/{ts_name}.ts");
64+
let (file_name, ext) = file_part.rsplit_once('.')?;
65+
if file_name.is_empty() {
66+
return None;
6967
}
68+
69+
let file_type = match ext {
70+
Self::M3U8_EXT => HlsFileType::Playlist,
71+
Self::TS_EXT => HlsFileType::Segment,
72+
_ => return None,
73+
};
74+
75+
Some(Self {
76+
app_name: app_name.into(),
77+
stream_name: stream_name.into(),
78+
file_name: file_name.into(),
79+
file_type,
80+
})
81+
}
82+
83+
fn to_file_path(&self) -> String {
84+
let ext = match self.file_type {
85+
HlsFileType::Playlist => Self::M3U8_EXT,
86+
HlsFileType::Segment => Self::TS_EXT,
87+
};
88+
format!(
89+
"./{}/{}/{}.{}",
90+
self.app_name, self.stream_name, self.file_name, ext
91+
)
7092
}
71-
simple_file_send(file_path.as_str()).await
7293
}
7394

74-
/// HTTP status code 404
75-
fn not_found() -> Response<Body> {
95+
fn response_unauthorized() -> Response<Body> {
96+
Response::builder()
97+
.status(StatusCode::UNAUTHORIZED)
98+
.body(UNAUTHORIZED.into())
99+
.unwrap()
100+
}
101+
102+
fn response_not_found() -> Response<Body> {
76103
Response::builder()
77104
.status(StatusCode::NOT_FOUND)
78105
.body(NOTFOUND.into())
79106
.unwrap()
80107
}
81108

82-
async fn simple_file_send(filename: &str) -> Response<Body> {
83-
// Serve a file by asynchronously reading it by chunks using tokio-util crate.
109+
async fn response_file(hls_path: &HlsPath) -> Response<Body> {
110+
let file_path = hls_path.to_file_path();
111+
112+
if let Ok(file) = File::open(&file_path).await {
113+
let builder = Response::builder().header("Content-Type", hls_path.file_type.content_type());
84114

85-
if let Ok(file) = File::open(filename).await {
115+
// Serve a file by asynchronously reading it by chunks using tokio-util crate.
86116
let stream = FramedRead::new(file, BytesCodec::new());
87-
let body = Body::from_stream(stream);
88-
return Response::new(body);
117+
return builder.body(Body::from_stream(stream)).unwrap();
118+
}
119+
120+
response_not_found()
121+
}
122+
123+
async fn handle_connection(State(auth): State<Option<Auth>>, req: Request<Body>) -> Response<Body> {
124+
let path = req.uri().path();
125+
let query_string = req.uri().query().map(|s| s.to_string());
126+
127+
let hls_path = match HlsPath::parse(path) {
128+
Some(p) => p,
129+
None => return response_not_found(),
130+
};
131+
132+
if let (Some(auth_val), HlsFileType::Playlist) = (auth.as_ref(), &hls_path.file_type) {
133+
if auth_val
134+
.authenticate(
135+
&hls_path.stream_name,
136+
&query_string.map(SecretCarrier::Query),
137+
true,
138+
)
139+
.is_err()
140+
{
141+
return response_unauthorized();
142+
}
89143
}
90144

91-
not_found()
145+
response_file(&hls_path).await
92146
}
93147

94148
pub async fn run(port: usize, auth: Option<Auth>) -> Result<()> {
@@ -105,3 +159,49 @@ pub async fn run(port: usize, auth: Option<Auth>) -> Result<()> {
105159

106160
Ok(())
107161
}
162+
163+
#[cfg(test)]
164+
mod tests {
165+
use super::{HlsFileType, HlsPath};
166+
167+
#[test]
168+
fn test_hls_path_parse() {
169+
// Playlist
170+
let playlist = HlsPath::parse("/live/stream/stream.m3u8").unwrap();
171+
assert_eq!(playlist.app_name, "live");
172+
assert_eq!(playlist.stream_name, "stream");
173+
assert_eq!(playlist.file_name, "stream");
174+
assert!(matches!(playlist.file_type, HlsFileType::Playlist));
175+
assert_eq!(playlist.to_file_path(), "./live/stream/stream.m3u8");
176+
assert_eq!(
177+
playlist.file_type.content_type(),
178+
"application/vnd.apple.mpegurl"
179+
);
180+
181+
// Segment
182+
let segment = HlsPath::parse("/live/stream/123.ts").unwrap();
183+
assert_eq!(segment.app_name, "live");
184+
assert_eq!(segment.stream_name, "stream");
185+
assert_eq!(segment.file_name, "123");
186+
assert!(matches!(segment.file_type, HlsFileType::Segment));
187+
assert_eq!(segment.to_file_path(), "./live/stream/123.ts");
188+
assert_eq!(segment.file_type.content_type(), "video/mp2t");
189+
190+
// Negative
191+
assert!(HlsPath::parse("").is_none());
192+
assert!(HlsPath::parse("/invalid").is_none());
193+
assert!(HlsPath::parse("/too/many/parts/of/path.m3u8").is_none());
194+
assert!(HlsPath::parse("/live/stream/invalid.mp4").is_none());
195+
assert!(HlsPath::parse("/live/stream/../../etc/passwd").is_none());
196+
assert!(HlsPath::parse("/live/stream/...").is_none());
197+
assert!(HlsPath::parse("/live/stream.m3u8").is_none());
198+
assert!(HlsPath::parse("/live/stream.ts").is_none());
199+
assert!(HlsPath::parse("/live/stream/").is_none());
200+
assert!(HlsPath::parse("/live/stream.m3u8").is_none());
201+
assert!(HlsPath::parse("/live/stream.ts").is_none());
202+
assert!(HlsPath::parse("/live/stream/file.").is_none());
203+
assert!(HlsPath::parse("/live/stream/.m3u8").is_none());
204+
assert!(HlsPath::parse("/live/stream/file.M3U8").is_none());
205+
assert!(HlsPath::parse("/live/stream/file.TS").is_none());
206+
}
207+
}

0 commit comments

Comments
 (0)