@@ -224,9 +224,6 @@ def _test_serialization(self, weights_only):
224
224
def test_serialization (self ):
225
225
self ._test_serialization (False )
226
226
227
- def test_serialization_safe (self ):
228
- self ._test_serialization (True )
229
-
230
227
def test_serialization_filelike (self ):
231
228
# Test serialization (load and save) with a filelike object
232
229
b = self ._test_serialization_data ()
@@ -362,9 +359,6 @@ def _test_serialization(conversion):
362
359
def test_serialization_sparse (self ):
363
360
self ._test_serialization (False )
364
361
365
- def test_serialization_sparse_safe (self ):
366
- self ._test_serialization (True )
367
-
368
362
def test_serialization_sparse_invalid (self ):
369
363
x = torch .zeros (3 , 3 )
370
364
x [1 ][1 ] = 1
@@ -510,9 +504,6 @@ def __reduce__(self):
510
504
def test_serialization_backwards_compat (self ):
511
505
self ._test_serialization_backwards_compat (False )
512
506
513
- def test_serialization_backwards_compat_safe (self ):
514
- self ._test_serialization_backwards_compat (True )
515
-
516
507
def test_serialization_save_warnings (self ):
517
508
with warnings .catch_warnings (record = True ) as warns :
518
509
with tempfile .NamedTemporaryFile () as checkpoint :
@@ -557,7 +548,8 @@ def load_bytes():
557
548
def check_map_locations (map_locations , dtype , intended_device ):
558
549
for fileobject_lambda in fileobject_lambdas :
559
550
for map_location in map_locations :
560
- tensor = torch .load (fileobject_lambda (), map_location = map_location )
551
+ # weigts_only=False as the downloaded file path uses the old serialization format
552
+ tensor = torch .load (fileobject_lambda (), map_location = map_location , weights_only = False )
561
553
562
554
self .assertEqual (tensor .device , intended_device )
563
555
self .assertEqual (tensor .dtype , dtype )
@@ -600,7 +592,8 @@ def test_load_nonexistent_device(self):
600
592
601
593
error_msg = r'Attempting to deserialize object on a CUDA device'
602
594
with self .assertRaisesRegex (RuntimeError , error_msg ):
603
- _ = torch .load (buf )
595
+ # weights_only=False as serialized is in legacy format
596
+ _ = torch .load (buf , weights_only = False )
604
597
605
598
@unittest .skipIf ((3 , 8 , 0 ) <= sys .version_info < (3 , 8 , 2 ), "See https://bugs.python.org/issue39681" )
606
599
def test_serialization_filelike_api_requirements (self ):
@@ -720,7 +713,8 @@ def test_serialization_storage_slice(self):
720
713
b'\x00 \x00 \x00 \x00 ' )
721
714
722
715
buf = io .BytesIO (serialized )
723
- (s1 , s2 ) = torch .load (buf )
716
+ # serialized was saved with PyTorch 0.3.1
717
+ (s1 , s2 ) = torch .load (buf , weights_only = False )
724
718
self .assertEqual (s1 [0 ], 0 )
725
719
self .assertEqual (s2 [0 ], 0 )
726
720
self .assertEqual (s1 .data_ptr () + 4 , s2 .data_ptr ())
@@ -837,6 +831,24 @@ def wrapper(*args, **kwargs):
837
831
def __exit__ (self , * args , ** kwargs ):
838
832
torch .save = self .torch_save
839
833
834
+
835
+ # used to set weights_only=False in _use_new_zipfile_serialization=False tests
836
+ class load_method :
837
+ def __init__ (self , weights_only ):
838
+ self .weights_only = weights_only
839
+ self .torch_load = torch .load
840
+
841
+ def __enter__ (self , * args , ** kwargs ):
842
+ def wrapper (* args , ** kwargs ):
843
+ kwargs ['weights_only' ] = self .weights_only
844
+ return self .torch_load (* args , ** kwargs )
845
+
846
+ torch .load = wrapper
847
+
848
+ def __exit__ (self , * args , ** kwargs ):
849
+ torch .load = self .torch_load
850
+
851
+
840
852
Point = namedtuple ('Point' , ['x' , 'y' ])
841
853
842
854
class ClassThatUsesBuildInstruction :
@@ -873,14 +885,25 @@ def test(f_new, f_old):
873
885
874
886
torch .save (x , f_old , _use_new_zipfile_serialization = False )
875
887
f_old .seek (0 )
876
- x_old_load = torch .load (f_old , weights_only = weights_only )
888
+ x_old_load = torch .load (f_old , weights_only = False )
877
889
self .assertEqual (x_old_load , x_new_load )
878
890
879
891
with AlwaysWarnTypedStorageRemoval (True ), warnings .catch_warnings (record = True ) as w :
880
892
with tempfile .NamedTemporaryFile () as f_new , tempfile .NamedTemporaryFile () as f_old :
881
893
test (f_new , f_old )
882
894
self .assertTrue (len (w ) == 0 , msg = f"Expected no warnings but got { [str (x ) for x in w ]} " )
883
895
896
+ def test_old_serialization_fails_with_weights_only (self ):
897
+ a = torch .randn (5 , 5 )
898
+ with BytesIOContext () as f :
899
+ torch .save (a , f , _use_new_zipfile_serialization = False )
900
+ f .seek (0 )
901
+ with self .assertRaisesRegex (
902
+ RuntimeError ,
903
+ "Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
904
+ ):
905
+ torch .load (f , weights_only = True )
906
+
884
907
885
908
class TestOldSerialization (TestCase , SerializationMixin ):
886
909
# unique_key is necessary because on Python 2.7, if a warning passed to
@@ -956,8 +979,7 @@ def test_serialization_offset(self):
956
979
self .assertEqual (i , i_loaded )
957
980
self .assertEqual (j , j_loaded )
958
981
959
- @parametrize ('weights_only' , (True , False ))
960
- def test_serialization_offset_filelike (self , weights_only ):
982
+ def test_serialization_offset_filelike (self ):
961
983
a = torch .randn (5 , 5 )
962
984
b = torch .randn (1024 , 1024 , 512 , dtype = torch .float32 )
963
985
i , j = 41 , 43
@@ -969,16 +991,16 @@ def test_serialization_offset_filelike(self, weights_only):
969
991
self .assertTrue (f .tell () > 2 * 1024 * 1024 * 1024 )
970
992
f .seek (0 )
971
993
i_loaded = pickle .load (f )
972
- a_loaded = torch .load (f , weights_only = weights_only )
994
+ a_loaded = torch .load (f )
973
995
j_loaded = pickle .load (f )
974
- b_loaded = torch .load (f , weights_only = weights_only )
996
+ b_loaded = torch .load (f )
975
997
self .assertTrue (torch .equal (a , a_loaded ))
976
998
self .assertTrue (torch .equal (b , b_loaded ))
977
999
self .assertEqual (i , i_loaded )
978
1000
self .assertEqual (j , j_loaded )
979
1001
980
1002
def run (self , * args , ** kwargs ):
981
- with serialization_method (use_zip = False ):
1003
+ with serialization_method (use_zip = False ), load_method ( weights_only = False ) :
982
1004
return super ().run (* args , ** kwargs )
983
1005
984
1006
0 commit comments