@@ -34,6 +34,11 @@ cdef class CoreProtocol:
34
34
35
35
self ._reset_result()
36
36
37
+ cpdef is_in_transaction(self ):
38
+ # PQTRANS_INTRANS = idle, within transaction block
39
+ # PQTRANS_INERROR = idle, within failed transaction
40
+ return self .xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
41
+
37
42
cdef _read_server_messages(self ):
38
43
cdef:
39
44
char mtype
@@ -263,27 +268,16 @@ cdef class CoreProtocol:
263
268
elif mtype == b' Z' :
264
269
# ReadyForQuery
265
270
self ._parse_msg_ready_for_query()
266
- if self .result_type == RESULT_FAILED:
267
- self ._push_result()
268
- else :
269
- try :
270
- buf = < WriteBuffer> next(self ._execute_iter)
271
- except StopIteration :
272
- self ._push_result()
273
- except Exception as e:
274
- self .result_type = RESULT_FAILED
275
- self .result = e
276
- self ._push_result()
277
- else :
278
- # Next iteration over the executemany() arg sequence
279
- self ._send_bind_message(
280
- self ._execute_portal_name, self ._execute_stmt_name,
281
- buf, 0 )
271
+ self ._push_result()
282
272
283
273
elif mtype == b' I' :
284
274
# EmptyQueryResponse
285
275
self .buffer.discard_message()
286
276
277
+ elif mtype == b' 1' :
278
+ # ParseComplete
279
+ self .buffer.discard_message()
280
+
287
281
cdef _process__bind(self , char mtype):
288
282
if mtype == b' E' :
289
283
# ErrorResponse
@@ -780,6 +774,17 @@ cdef class CoreProtocol:
780
774
if self .con_status != CONNECTION_OK:
781
775
raise apg_exc.InternalClientError(' not connected' )
782
776
777
+ cdef WriteBuffer _build_parse_message(self , str stmt_name, str query):
778
+ cdef WriteBuffer buf
779
+
780
+ buf = WriteBuffer.new_message(b' P' )
781
+ buf.write_str(stmt_name, self .encoding)
782
+ buf.write_str(query, self .encoding)
783
+ buf.write_int16(0 )
784
+
785
+ buf.end_message()
786
+ return buf
787
+
783
788
cdef WriteBuffer _build_bind_message(self , str portal_name,
784
789
str stmt_name,
785
790
WriteBuffer bind_data):
@@ -795,6 +800,25 @@ cdef class CoreProtocol:
795
800
buf.end_message()
796
801
return buf
797
802
803
+ cdef WriteBuffer _build_empty_bind_data(self ):
804
+ cdef WriteBuffer buf
805
+ buf = WriteBuffer.new()
806
+ buf.write_int16(0 ) # The number of parameter format codes
807
+ buf.write_int16(0 ) # The number of parameter values
808
+ buf.write_int16(0 ) # The number of result-column format codes
809
+ return buf
810
+
811
+ cdef WriteBuffer _build_execute_message(self , str portal_name,
812
+ int32_t limit):
813
+ cdef WriteBuffer buf
814
+
815
+ buf = WriteBuffer.new_message(b' E' )
816
+ buf.write_str(portal_name, self .encoding) # name of the portal
817
+ buf.write_int32(limit) # number of rows to return; 0 - all
818
+
819
+ buf.end_message()
820
+ return buf
821
+
798
822
# API for subclasses
799
823
800
824
cdef _connect(self ):
@@ -845,12 +869,7 @@ cdef class CoreProtocol:
845
869
self ._ensure_connected()
846
870
self ._set_state(PROTOCOL_PREPARE)
847
871
848
- buf = WriteBuffer.new_message(b' P' )
849
- buf.write_str(stmt_name, self .encoding)
850
- buf.write_str(query, self .encoding)
851
- buf.write_int16(0 )
852
- buf.end_message()
853
- packet = buf
872
+ packet = self ._build_parse_message(stmt_name, query)
854
873
855
874
buf = WriteBuffer.new_message(b' D' )
856
875
buf.write_byte(b' S' )
@@ -872,10 +891,7 @@ cdef class CoreProtocol:
872
891
buf = self ._build_bind_message(portal_name, stmt_name, bind_data)
873
892
packet = buf
874
893
875
- buf = WriteBuffer.new_message(b' E' )
876
- buf.write_str(portal_name, self .encoding) # name of the portal
877
- buf.write_int32(limit) # number of rows to return; 0 - all
878
- buf.end_message()
894
+ buf = self ._build_execute_message(portal_name, limit)
879
895
packet.write_buffer(buf)
880
896
881
897
packet.write_bytes(SYNC_MESSAGE)
@@ -894,11 +910,8 @@ cdef class CoreProtocol:
894
910
895
911
self ._send_bind_message(portal_name, stmt_name, bind_data, limit)
896
912
897
- cdef _bind_execute_many(self , str portal_name, str stmt_name,
898
- object bind_data):
899
-
900
- cdef WriteBuffer buf
901
-
913
+ cdef bint _bind_execute_many(self , str portal_name, str stmt_name,
914
+ object bind_data):
902
915
self ._ensure_connected()
903
916
self ._set_state(PROTOCOL_BIND_EXECUTE_MANY)
904
917
@@ -907,17 +920,88 @@ cdef class CoreProtocol:
907
920
self ._execute_iter = bind_data
908
921
self ._execute_portal_name = portal_name
909
922
self ._execute_stmt_name = stmt_name
923
+ return self ._bind_execute_many_more(True )
910
924
911
- try :
912
- buf = < WriteBuffer> next(bind_data)
913
- except StopIteration :
914
- self ._push_result()
915
- except Exception as e:
916
- self .result_type = RESULT_FAILED
917
- self .result = e
918
- self ._push_result()
919
- else :
920
- self ._send_bind_message(portal_name, stmt_name, buf, 0 )
925
+ cdef bint _bind_execute_many_more(self , bint first = False ):
926
+ cdef:
927
+ WriteBuffer packet
928
+ WriteBuffer buf
929
+ list buffers = []
930
+
931
+ # as we keep sending, the server may return an error early
932
+ if self .result_type == RESULT_FAILED:
933
+ self ._write(SYNC_MESSAGE)
934
+ return False
935
+
936
+ # collect up to four 32KB buffers to send
937
+ # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
938
+ while len (buffers) < _EXECUTE_MANY_BUF_NUM:
939
+ packet = WriteBuffer.new()
940
+
941
+ # fill one 32KB buffer
942
+ while packet.len() < _EXECUTE_MANY_BUF_SIZE:
943
+ try :
944
+ # grab one item from the input
945
+ buf = < WriteBuffer> next(self ._execute_iter)
946
+
947
+ # reached the end of the input
948
+ except StopIteration :
949
+ if first:
950
+ # if we never send anything, simply set the result
951
+ self ._push_result()
952
+ else :
953
+ # otherwise, append SYNC and send the buffers
954
+ packet.write_bytes(SYNC_MESSAGE)
955
+ buffers.append(packet)
956
+ self ._writelines(buffers)
957
+ return False
958
+
959
+ # error in input, give up the buffers and cleanup
960
+ except Exception as ex:
961
+ self .result_type = RESULT_FAILED
962
+ self .result = ex
963
+ if first:
964
+ self ._push_result()
965
+ elif self .is_in_transaction():
966
+ # we're in an explicit transaction, just SYNC
967
+ self ._write(SYNC_MESSAGE)
968
+ else :
969
+ # In an implicit transaction, if `ignore_till_sync`,
970
+ # `ROLLBACK` will be ignored and `Sync` will restore
971
+ # the state; or the transaction will be rolled back
972
+ # with a warning saying that there was no transaction,
973
+ # but rollback is done anyway, so we could safely
974
+ # ignore this warning.
975
+ # GOTCHA: simple query message will be ignored if
976
+ # `ignore_till_sync` is set.
977
+ buf = self ._build_parse_message(' ' , ' ROLLBACK' )
978
+ buf.write_buffer(self ._build_bind_message(
979
+ ' ' , ' ' , self ._build_empty_bind_data()))
980
+ buf.write_buffer(self ._build_execute_message(' ' , 0 ))
981
+ buf.write_bytes(SYNC_MESSAGE)
982
+ self ._write(buf)
983
+ return False
984
+
985
+ # all good, write to the buffer
986
+ first = False
987
+ packet.write_buffer(
988
+ self ._build_bind_message(
989
+ self ._execute_portal_name,
990
+ self ._execute_stmt_name,
991
+ buf,
992
+ )
993
+ )
994
+ packet.write_buffer(
995
+ self ._build_execute_message(self ._execute_portal_name, 0 ,
996
+ )
997
+ )
998
+
999
+ # collected one buffer
1000
+ buffers.append(packet)
1001
+
1002
+ # write to the wire, and signal the caller for more to send
1003
+ self ._writelines(buffers)
1004
+ return True
921
1005
922
1006
cdef _execute(self , str portal_name, int32_t limit):
923
1007
cdef WriteBuffer buf
@@ -927,10 +1011,7 @@ cdef class CoreProtocol:
927
1011
928
1012
self .result = []
929
1013
930
- buf = WriteBuffer.new_message(b' E' )
931
- buf.write_str(portal_name, self .encoding) # name of the portal
932
- buf.write_int32(limit) # number of rows to return; 0 - all
933
- buf.end_message()
1014
+ buf = self ._build_execute_message(portal_name, limit)
934
1015
935
1016
buf.write_bytes(SYNC_MESSAGE)
936
1017
@@ -1013,6 +1094,9 @@ cdef class CoreProtocol:
1013
1094
cdef _write(self , buf):
1014
1095
raise NotImplementedError
1015
1096
1097
+ cdef _writelines(self , list buffers):
1098
+ raise NotImplementedError
1099
+
1016
1100
cdef _decode_row(self , const char * buf, ssize_t buf_len):
1017
1101
pass
1018
1102
0 commit comments