From b0c65c677f5298df8653df1e382b406bea420ba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergey=20M=E2=80=A4?= Date: Sat, 17 Dec 2016 18:44:53 +0700 Subject: [PATCH] [utils] Improve urljoin --- test/test_utils.py | 3 +++ youtube_dl/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3f45b0bd1..1cdac82fc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -448,11 +448,14 @@ class TestUtil(unittest.TestCase): def test_urljoin(self): self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('//foo.de/', '/a/b/c.txt'), '//foo.de/a/b/c.txt') self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt') self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') self.assertEqual(urljoin('http://foo.de', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt') self.assertEqual(urljoin('http://foo.de/', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('http://foo.de/', '//foo.de/a/b/c.txt'), '//foo.de/a/b/c.txt') self.assertEqual(urljoin(None, 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin(None, '//foo.de/a/b/c.txt'), '//foo.de/a/b/c.txt') self.assertEqual(urljoin('', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') self.assertEqual(urljoin(['foobar'], 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') self.assertEqual(urljoin('http://foo.de/', None), None) diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 694e9a600..528d87bb9 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -1703,9 +1703,9 @@ def base_url(url): def urljoin(base, path): if not isinstance(path, compat_str) or not path: return None - if re.match(r'https?://', path): + if re.match(r'^(?:https?:)?//', path): return path - if not isinstance(base, compat_str) or not re.match(r'https?://', base): + if not isinstance(base, compat_str) or not re.match(r'^(?:https?:)?//', base): return None return compat_urlparse.urljoin(base, path)