@@ -27,13 +27,13 @@ cdef class CoreProtocol:
27
27
# type of `scram` is `SCRAMAuthentcation`
28
28
self .scram = None
29
29
30
- # executemany support data
31
- self ._execute_iter = None
32
- self ._execute_portal_name = None
33
- self ._execute_stmt_name = None
34
-
35
30
self ._reset_result()
36
31
32
+ cpdef is_in_transaction(self ):
33
+ # PQTRANS_INTRANS = idle, within transaction block
34
+ # PQTRANS_INERROR = idle, within failed transaction
35
+ return self .xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
36
+
37
37
cdef _read_server_messages(self ):
38
38
cdef:
39
39
char mtype
@@ -258,22 +258,7 @@ cdef class CoreProtocol:
258
258
elif mtype == b' Z' :
259
259
# ReadyForQuery
260
260
self ._parse_msg_ready_for_query()
261
- if self .result_type == RESULT_FAILED:
262
- self ._push_result()
263
- else :
264
- try :
265
- buf = < WriteBuffer> next(self ._execute_iter)
266
- except StopIteration :
267
- self ._push_result()
268
- except Exception as e:
269
- self .result_type = RESULT_FAILED
270
- self .result = e
271
- self ._push_result()
272
- else :
273
- # Next iteration over the executemany() arg sequence
274
- self ._send_bind_message(
275
- self ._execute_portal_name, self ._execute_stmt_name,
276
- buf, 0 )
261
+ self ._push_result()
277
262
278
263
elif mtype == b' I' :
279
264
# EmptyQueryResponse
@@ -775,6 +760,17 @@ cdef class CoreProtocol:
775
760
if self .con_status != CONNECTION_OK:
776
761
raise apg_exc.InternalClientError(' not connected' )
777
762
763
+ cdef WriteBuffer _build_parse_message(self , str stmt_name, str query):
764
+ cdef WriteBuffer buf
765
+
766
+ buf = WriteBuffer.new_message(b' P' )
767
+ buf.write_str(stmt_name, self .encoding)
768
+ buf.write_str(query, self .encoding)
769
+ buf.write_int16(0 )
770
+
771
+ buf.end_message()
772
+ return buf
773
+
778
774
cdef WriteBuffer _build_bind_message(self , str portal_name,
779
775
str stmt_name,
780
776
WriteBuffer bind_data):
@@ -790,6 +786,25 @@ cdef class CoreProtocol:
790
786
buf.end_message()
791
787
return buf
792
788
789
+ cdef WriteBuffer _build_empty_bind_data(self ):
790
+ cdef WriteBuffer buf
791
+ buf = WriteBuffer.new()
792
+ buf.write_int16(0 ) # The number of parameter format codes
793
+ buf.write_int16(0 ) # The number of parameter values
794
+ buf.write_int16(0 ) # The number of result-column format codes
795
+ return buf
796
+
797
+ cdef WriteBuffer _build_execute_message(self , str portal_name,
798
+ int32_t limit):
799
+ cdef WriteBuffer buf
800
+
801
+ buf = WriteBuffer.new_message(b' E' )
802
+ buf.write_str(portal_name, self .encoding) # name of the portal
803
+ buf.write_int32(limit) # number of rows to return; 0 - all
804
+
805
+ buf.end_message()
806
+ return buf
807
+
793
808
# API for subclasses
794
809
795
810
cdef _connect(self ):
@@ -840,12 +855,7 @@ cdef class CoreProtocol:
840
855
self ._ensure_connected()
841
856
self ._set_state(PROTOCOL_PREPARE)
842
857
843
- buf = WriteBuffer.new_message(b' P' )
844
- buf.write_str(stmt_name, self .encoding)
845
- buf.write_str(query, self .encoding)
846
- buf.write_int16(0 )
847
- buf.end_message()
848
- packet = buf
858
+ packet = self ._build_parse_message(stmt_name, query)
849
859
850
860
buf = WriteBuffer.new_message(b' D' )
851
861
buf.write_byte(b' S' )
@@ -867,10 +877,7 @@ cdef class CoreProtocol:
867
877
buf = self ._build_bind_message(portal_name, stmt_name, bind_data)
868
878
packet = buf
869
879
870
- buf = WriteBuffer.new_message(b' E' )
871
- buf.write_str(portal_name, self .encoding) # name of the portal
872
- buf.write_int32(limit) # number of rows to return; 0 - all
873
- buf.end_message()
880
+ buf = self ._build_execute_message(portal_name, limit)
874
881
packet.write_buffer(buf)
875
882
876
883
packet.write_bytes(SYNC_MESSAGE)
@@ -889,30 +896,75 @@ cdef class CoreProtocol:
889
896
890
897
self ._send_bind_message(portal_name, stmt_name, bind_data, limit)
891
898
892
- cdef _bind_execute_many(self , str portal_name, str stmt_name,
893
- object bind_data):
894
-
895
- cdef WriteBuffer buf
896
-
899
+ cdef _execute_many_init(self ):
897
900
self ._ensure_connected()
898
901
self ._set_state(PROTOCOL_BIND_EXECUTE_MANY)
899
902
900
903
self .result = None
901
904
self ._discard_data = True
902
- self ._execute_iter = bind_data
903
- self ._execute_portal_name = portal_name
904
- self ._execute_stmt_name = stmt_name
905
905
906
- try :
907
- buf = < WriteBuffer> next(bind_data)
908
- except StopIteration :
909
- self ._push_result()
910
- except Exception as e:
911
- self .result_type = RESULT_FAILED
912
- self .result = e
906
+ cdef _execute_many_writelines(self , str portal_name, str stmt_name,
907
+ object bind_data):
908
+ cdef:
909
+ WriteBuffer packet
910
+ WriteBuffer buf
911
+ list buffers = []
912
+
913
+ if self .result_type == RESULT_FAILED:
914
+ raise StopIteration (True )
915
+
916
+ while len (buffers) < _EXECUTE_MANY_BUF_NUM:
917
+ packet = WriteBuffer.new()
918
+
919
+ while packet.len() < _EXECUTE_MANY_BUF_SIZE:
920
+ try :
921
+ buf = < WriteBuffer> next(bind_data)
922
+ except StopIteration :
923
+ if packet.len() > 0 :
924
+ buffers.append(packet)
925
+ if len (buffers) > 0 :
926
+ self ._writelines(buffers)
927
+ raise StopIteration (True )
928
+ else :
929
+ raise StopIteration (False )
930
+ except Exception as ex:
931
+ raise StopIteration (ex)
932
+ packet.write_buffer(
933
+ self ._build_bind_message(portal_name, stmt_name, buf))
934
+ packet.write_buffer(
935
+ self ._build_execute_message(portal_name, 0 ))
936
+ buffers.append(packet)
937
+ self ._writelines(buffers)
938
+
939
+ cdef _execute_many_done(self , bint data_sent):
940
+ if data_sent:
941
+ self ._write(SYNC_MESSAGE)
942
+ else :
913
943
self ._push_result()
944
+
945
+ cdef _execute_many_fail(self , object error):
946
+ cdef WriteBuffer buf
947
+
948
+ self .result_type = RESULT_FAILED
949
+ self .result = error
950
+
951
+ # We shall rollback in an implicit transaction to prevent partial
952
+ # commit, while do nothing in an explicit transaction and leaving the
953
+ # error to the user
954
+ if self .is_in_transaction():
955
+ self ._execute_many_done(True )
914
956
else :
915
- self ._send_bind_message(portal_name, stmt_name, buf, 0 )
957
+ # Here if the implicit transaction is in `ignore_till_sync` mode,
958
+ # the `ROLLBACK` will be ignored and `Sync` will restore the state;
959
+ # or else the implicit transaction will be rolled back with a
960
+ # warning saying that there was no transaction, but rollback is
961
+ # done anyway, so we could ignore this warning.
962
+ buf = self ._build_parse_message(' ' , ' ROLLBACK' )
963
+ buf.write_buffer(self ._build_bind_message(
964
+ ' ' , ' ' , self ._build_empty_bind_data()))
965
+ buf.write_buffer(self ._build_execute_message(' ' , 0 ))
966
+ buf.write_bytes(SYNC_MESSAGE)
967
+ self ._write(buf)
916
968
917
969
cdef _execute(self , str portal_name, int32_t limit):
918
970
cdef WriteBuffer buf
@@ -922,10 +974,7 @@ cdef class CoreProtocol:
922
974
923
975
self .result = []
924
976
925
- buf = WriteBuffer.new_message(b' E' )
926
- buf.write_str(portal_name, self .encoding) # name of the portal
927
- buf.write_int32(limit) # number of rows to return; 0 - all
928
- buf.end_message()
977
+ buf = self ._build_execute_message(portal_name, limit)
929
978
930
979
buf.write_bytes(SYNC_MESSAGE)
931
980
@@ -1008,6 +1057,9 @@ cdef class CoreProtocol:
1008
1057
cdef _write(self , buf):
1009
1058
raise NotImplementedError
1010
1059
1060
+ cdef _writelines(self , list buffers):
1061
+ raise NotImplementedError
1062
+
1011
1063
cdef _decode_row(self , const char * buf, ssize_t buf_len):
1012
1064
pass
1013
1065
0 commit comments