Skip to content

Commit cd15d7b

Browse files
Prevent _legacy_load with weights_only=True (#144993)
Prevent _legacy_load with weights_only=True (#144914) Pull Request resolved: #144914 Approved by: https://github.com/malfet, https://github.com/albanD (cherry picked from commit 7c3aa1d) Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
1 parent a2639bc commit cd15d7b

File tree

3 files changed

+48
-29
lines changed

3 files changed

+48
-29
lines changed

test/quantization/bc/test_backward_compatibility.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ def _test_op(
112112
torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
113113
torch.save(qmodule(input_tensor), expected_file)
114114

115-
input_tensor = torch.load(input_file)
115+
# weights_only=False as file was saved in .tar format
116+
input_tensor = torch.load(input_file, weights_only=False)
116117
# weights_only = False as sometimes get ScriptObject here
117118
qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
118119
qmodule_scripted = torch.jit.load(scripted_module_file)
119120
qmodule_traced = torch.jit.load(traced_module_file)
120-
expected = torch.load(expected_file)
121+
# weights_only=False as file was saved in .tar format
122+
expected = torch.load(expected_file, weights_only=False)
121123
self.assertEqual(qmodule(input_tensor), expected, atol=prec)
122124
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
123125
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)

test/test_serialization.py

+40-18
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,6 @@ def _test_serialization(self, weights_only):
224224
def test_serialization(self):
225225
self._test_serialization(False)
226226

227-
def test_serialization_safe(self):
228-
self._test_serialization(True)
229-
230227
def test_serialization_filelike(self):
231228
# Test serialization (load and save) with a filelike object
232229
b = self._test_serialization_data()
@@ -362,9 +359,6 @@ def _test_serialization(conversion):
362359
def test_serialization_sparse(self):
363360
self._test_serialization(False)
364361

365-
def test_serialization_sparse_safe(self):
366-
self._test_serialization(True)
367-
368362
def test_serialization_sparse_invalid(self):
369363
x = torch.zeros(3, 3)
370364
x[1][1] = 1
@@ -510,9 +504,6 @@ def __reduce__(self):
510504
def test_serialization_backwards_compat(self):
511505
self._test_serialization_backwards_compat(False)
512506

513-
def test_serialization_backwards_compat_safe(self):
514-
self._test_serialization_backwards_compat(True)
515-
516507
def test_serialization_save_warnings(self):
517508
with warnings.catch_warnings(record=True) as warns:
518509
with tempfile.NamedTemporaryFile() as checkpoint:
@@ -557,7 +548,8 @@ def load_bytes():
557548
def check_map_locations(map_locations, dtype, intended_device):
558549
for fileobject_lambda in fileobject_lambdas:
559550
for map_location in map_locations:
560-
tensor = torch.load(fileobject_lambda(), map_location=map_location)
551+
# weigts_only=False as the downloaded file path uses the old serialization format
552+
tensor = torch.load(fileobject_lambda(), map_location=map_location, weights_only=False)
561553

562554
self.assertEqual(tensor.device, intended_device)
563555
self.assertEqual(tensor.dtype, dtype)
@@ -600,7 +592,8 @@ def test_load_nonexistent_device(self):
600592

601593
error_msg = r'Attempting to deserialize object on a CUDA device'
602594
with self.assertRaisesRegex(RuntimeError, error_msg):
603-
_ = torch.load(buf)
595+
# weights_only=False as serialized is in legacy format
596+
_ = torch.load(buf, weights_only=False)
604597

605598
@unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681")
606599
def test_serialization_filelike_api_requirements(self):
@@ -720,7 +713,8 @@ def test_serialization_storage_slice(self):
720713
b'\x00\x00\x00\x00')
721714

722715
buf = io.BytesIO(serialized)
723-
(s1, s2) = torch.load(buf)
716+
# serialized was saved with PyTorch 0.3.1
717+
(s1, s2) = torch.load(buf, weights_only=False)
724718
self.assertEqual(s1[0], 0)
725719
self.assertEqual(s2[0], 0)
726720
self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
@@ -837,6 +831,24 @@ def wrapper(*args, **kwargs):
837831
def __exit__(self, *args, **kwargs):
838832
torch.save = self.torch_save
839833

834+
835+
# used to set weights_only=False in _use_new_zipfile_serialization=False tests
836+
class load_method:
837+
def __init__(self, weights_only):
838+
self.weights_only = weights_only
839+
self.torch_load = torch.load
840+
841+
def __enter__(self, *args, **kwargs):
842+
def wrapper(*args, **kwargs):
843+
kwargs['weights_only'] = self.weights_only
844+
return self.torch_load(*args, **kwargs)
845+
846+
torch.load = wrapper
847+
848+
def __exit__(self, *args, **kwargs):
849+
torch.load = self.torch_load
850+
851+
840852
Point = namedtuple('Point', ['x', 'y'])
841853

842854
class ClassThatUsesBuildInstruction:
@@ -873,14 +885,25 @@ def test(f_new, f_old):
873885

874886
torch.save(x, f_old, _use_new_zipfile_serialization=False)
875887
f_old.seek(0)
876-
x_old_load = torch.load(f_old, weights_only=weights_only)
888+
x_old_load = torch.load(f_old, weights_only=False)
877889
self.assertEqual(x_old_load, x_new_load)
878890

879891
with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
880892
with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
881893
test(f_new, f_old)
882894
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
883895

896+
def test_old_serialization_fails_with_weights_only(self):
897+
a = torch.randn(5, 5)
898+
with BytesIOContext() as f:
899+
torch.save(a, f, _use_new_zipfile_serialization=False)
900+
f.seek(0)
901+
with self.assertRaisesRegex(
902+
RuntimeError,
903+
"Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
904+
):
905+
torch.load(f, weights_only=True)
906+
884907

885908
class TestOldSerialization(TestCase, SerializationMixin):
886909
# unique_key is necessary because on Python 2.7, if a warning passed to
@@ -956,8 +979,7 @@ def test_serialization_offset(self):
956979
self.assertEqual(i, i_loaded)
957980
self.assertEqual(j, j_loaded)
958981

959-
@parametrize('weights_only', (True, False))
960-
def test_serialization_offset_filelike(self, weights_only):
982+
def test_serialization_offset_filelike(self):
961983
a = torch.randn(5, 5)
962984
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
963985
i, j = 41, 43
@@ -969,16 +991,16 @@ def test_serialization_offset_filelike(self, weights_only):
969991
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
970992
f.seek(0)
971993
i_loaded = pickle.load(f)
972-
a_loaded = torch.load(f, weights_only=weights_only)
994+
a_loaded = torch.load(f)
973995
j_loaded = pickle.load(f)
974-
b_loaded = torch.load(f, weights_only=weights_only)
996+
b_loaded = torch.load(f)
975997
self.assertTrue(torch.equal(a, a_loaded))
976998
self.assertTrue(torch.equal(b, b_loaded))
977999
self.assertEqual(i, i_loaded)
9781000
self.assertEqual(j, j_loaded)
9791001

9801002
def run(self, *args, **kwargs):
981-
with serialization_method(use_zip=False):
1003+
with serialization_method(use_zip=False), load_method(weights_only=False):
9821004
return super().run(*args, **kwargs)
9831005

9841006

torch/serialization.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -1482,15 +1482,10 @@ def _get_wo_message(message: str) -> str:
14821482
"please torch.save your checkpoint with this option in order to use mmap."
14831483
)
14841484
if weights_only:
1485-
try:
1486-
return _legacy_load(
1487-
opened_file,
1488-
map_location,
1489-
_weights_only_unpickler,
1490-
**pickle_load_args,
1491-
)
1492-
except pickle.UnpicklingError as e:
1493-
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
1485+
raise RuntimeError(
1486+
"Cannot use ``weights_only=True`` with files saved in the "
1487+
".tar format used before version 1.6. " + UNSAFE_MESSAGE
1488+
)
14941489
return _legacy_load(
14951490
opened_file, map_location, pickle_module, **pickle_load_args
14961491
)

0 commit comments

Comments
 (0)