aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSteve Dower <steve.dower@microsoft.com>2019-03-29 16:37:16 -0700
committerGitHub <noreply@github.com>2019-03-29 16:37:16 -0700
commit2438cdf0e932a341c7613bf4323d06b91ae9f1f1 (patch)
tree231cdf3f22e1d5eb9f88fe7a511ab47e3cf8d225 /Lib/ctypes
parentbpo-35947: Update Windows to the current version of libffi (GH-11797) (diff)
downloadcpython-2438cdf0e932a341c7613bf4323d06b91ae9f1f1.tar.gz
cpython-2438cdf0e932a341c7613bf4323d06b91ae9f1f1.tar.bz2
cpython-2438cdf0e932a341c7613bf4323d06b91ae9f1f1.zip
bpo-36085: Enable better DLL resolution on Windows (GH-12302)
Diffstat (limited to 'Lib/ctypes')
-rw-r--r--Lib/ctypes/__init__.py12
-rw-r--r--Lib/ctypes/test/test_loading.py63
2 files changed, 74 insertions, 1 deletions
diff --git a/Lib/ctypes/__init__.py b/Lib/ctypes/__init__.py
index 5f78beda586..4107db3e397 100644
--- a/Lib/ctypes/__init__.py
+++ b/Lib/ctypes/__init__.py
@@ -326,7 +326,8 @@ class CDLL(object):
def __init__(self, name, mode=DEFAULT_MODE, handle=None,
use_errno=False,
- use_last_error=False):
+ use_last_error=False,
+ winmode=None):
self._name = name
flags = self._func_flags_
if use_errno:
@@ -341,6 +342,15 @@ class CDLL(object):
"""
if name and name.endswith(")") and ".a(" in name:
mode |= ( _os.RTLD_MEMBER | _os.RTLD_NOW )
+ if _os.name == "nt":
+ if winmode is not None:
+ mode = winmode
+ else:
+ import nt
+ mode = nt._LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
+ if '/' in name or '\\' in name:
+ self._name = nt._getfullpathname(self._name)
+ mode |= nt._LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
class _FuncPtr(_CFuncPtr):
_flags_ = flags
diff --git a/Lib/ctypes/test/test_loading.py b/Lib/ctypes/test/test_loading.py
index f3b65b9d6e7..be367c6fa35 100644
--- a/Lib/ctypes/test/test_loading.py
+++ b/Lib/ctypes/test/test_loading.py
@@ -1,6 +1,9 @@
from ctypes import *
import os
+import shutil
+import subprocess
import sys
+import sysconfig
import unittest
import test.support
from ctypes.util import find_library
@@ -112,5 +115,65 @@ class LoaderTest(unittest.TestCase):
# This is the real test: call the function via 'call_function'
self.assertEqual(0, call_function(proc, (None,)))
+ @unittest.skipUnless(os.name == "nt",
+ 'test specific to Windows')
+ def test_load_dll_with_flags(self):
+ _sqlite3 = test.support.import_module("_sqlite3")
+ src = _sqlite3.__file__
+ if src.lower().endswith("_d.pyd"):
+ ext = "_d.dll"
+ else:
+ ext = ".dll"
+
+ with test.support.temp_dir() as tmp:
+ # We copy two files and load _sqlite3.dll (formerly .pyd),
+ # which has a dependency on sqlite3.dll. Then we test
+ # loading it in subprocesses to avoid it starting in memory
+ # for each test.
+ target = os.path.join(tmp, "_sqlite3.dll")
+ shutil.copy(src, target)
+ shutil.copy(os.path.join(os.path.dirname(src), "sqlite3" + ext),
+ os.path.join(tmp, "sqlite3" + ext))
+
+ def should_pass(command):
+ with self.subTest(command):
+ subprocess.check_output(
+ [sys.executable, "-c",
+ "from ctypes import *; import nt;" + command],
+ cwd=tmp
+ )
+
+ def should_fail(command):
+ with self.subTest(command):
+ with self.assertRaises(subprocess.CalledProcessError):
+ subprocess.check_output(
+ [sys.executable, "-c",
+ "from ctypes import *; import nt;" + command],
+ cwd=tmp, stderr=subprocess.STDOUT,
+ )
+
+ # Default load should not find this in CWD
+ should_fail("WinDLL('_sqlite3.dll')")
+
+ # Relative path (but not just filename) should succeed
+ should_pass("WinDLL('./_sqlite3.dll')")
+
+ # Insecure load flags should succeed
+ should_pass("WinDLL('_sqlite3.dll', winmode=0)")
+
+ # Full path load without DLL_LOAD_DIR shouldn't find dependency
+ should_fail("WinDLL(nt._getfullpathname('_sqlite3.dll'), " +
+ "winmode=nt._LOAD_LIBRARY_SEARCH_SYSTEM32)")
+
+ # Full path load with DLL_LOAD_DIR should succeed
+ should_pass("WinDLL(nt._getfullpathname('_sqlite3.dll'), " +
+ "winmode=nt._LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)")
+
+ # User-specified directory should succeed
+ should_pass("import os; p = os.add_dll_directory(os.getcwd());" +
+ "WinDLL('_sqlite3.dll'); p.close()")
+
+
+
if __name__ == "__main__":
unittest.main()