@@ -27,13 +27,13 @@ cdef class CoreProtocol:
27
27
# type of `scram` is `SCRAMAuthentcation`
28
28
self .scram = None
29
29
30
- # executemany support data
31
- self ._execute_iter = None
32
- self ._execute_portal_name = None
33
- self ._execute_stmt_name = None
34
-
35
30
self ._reset_result()
36
31
32
+ cpdef is_in_transaction(self ):
33
+ # PQTRANS_INTRANS = idle, within transaction block
34
+ # PQTRANS_INERROR = idle, within failed transaction
35
+ return self .xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
36
+
37
37
cdef _read_server_messages(self ):
38
38
cdef:
39
39
char mtype
@@ -263,22 +263,7 @@ cdef class CoreProtocol:
263
263
elif mtype == b' Z' :
264
264
# ReadyForQuery
265
265
self ._parse_msg_ready_for_query()
266
- if self .result_type == RESULT_FAILED:
267
- self ._push_result()
268
- else :
269
- try :
270
- buf = < WriteBuffer> next(self ._execute_iter)
271
- except StopIteration :
272
- self ._push_result()
273
- except Exception as e:
274
- self .result_type = RESULT_FAILED
275
- self .result = e
276
- self ._push_result()
277
- else :
278
- # Next iteration over the executemany() arg sequence
279
- self ._send_bind_message(
280
- self ._execute_portal_name, self ._execute_stmt_name,
281
- buf, 0 )
266
+ self ._push_result()
282
267
283
268
elif mtype == b' I' :
284
269
# EmptyQueryResponse
@@ -780,6 +765,17 @@ cdef class CoreProtocol:
780
765
if self .con_status != CONNECTION_OK:
781
766
raise apg_exc.InternalClientError(' not connected' )
782
767
768
+ cdef WriteBuffer _build_parse_message(self , str stmt_name, str query):
769
+ cdef WriteBuffer buf
770
+
771
+ buf = WriteBuffer.new_message(b' P' )
772
+ buf.write_str(stmt_name, self .encoding)
773
+ buf.write_str(query, self .encoding)
774
+ buf.write_int16(0 )
775
+
776
+ buf.end_message()
777
+ return buf
778
+
783
779
cdef WriteBuffer _build_bind_message(self , str portal_name,
784
780
str stmt_name,
785
781
WriteBuffer bind_data):
@@ -795,6 +791,25 @@ cdef class CoreProtocol:
795
791
buf.end_message()
796
792
return buf
797
793
794
+ cdef WriteBuffer _build_empty_bind_data(self ):
795
+ cdef WriteBuffer buf
796
+ buf = WriteBuffer.new()
797
+ buf.write_int16(0 ) # The number of parameter format codes
798
+ buf.write_int16(0 ) # The number of parameter values
799
+ buf.write_int16(0 ) # The number of result-column format codes
800
+ return buf
801
+
802
+ cdef WriteBuffer _build_execute_message(self , str portal_name,
803
+ int32_t limit):
804
+ cdef WriteBuffer buf
805
+
806
+ buf = WriteBuffer.new_message(b' E' )
807
+ buf.write_str(portal_name, self .encoding) # name of the portal
808
+ buf.write_int32(limit) # number of rows to return; 0 - all
809
+
810
+ buf.end_message()
811
+ return buf
812
+
798
813
# API for subclasses
799
814
800
815
cdef _connect(self ):
@@ -845,12 +860,7 @@ cdef class CoreProtocol:
845
860
self ._ensure_connected()
846
861
self ._set_state(PROTOCOL_PREPARE)
847
862
848
- buf = WriteBuffer.new_message(b' P' )
849
- buf.write_str(stmt_name, self .encoding)
850
- buf.write_str(query, self .encoding)
851
- buf.write_int16(0 )
852
- buf.end_message()
853
- packet = buf
863
+ packet = self ._build_parse_message(stmt_name, query)
854
864
855
865
buf = WriteBuffer.new_message(b' D' )
856
866
buf.write_byte(b' S' )
@@ -872,10 +882,7 @@ cdef class CoreProtocol:
872
882
buf = self ._build_bind_message(portal_name, stmt_name, bind_data)
873
883
packet = buf
874
884
875
- buf = WriteBuffer.new_message(b' E' )
876
- buf.write_str(portal_name, self .encoding) # name of the portal
877
- buf.write_int32(limit) # number of rows to return; 0 - all
878
- buf.end_message()
885
+ buf = self ._build_execute_message(portal_name, limit)
879
886
packet.write_buffer(buf)
880
887
881
888
packet.write_bytes(SYNC_MESSAGE)
@@ -894,30 +901,75 @@ cdef class CoreProtocol:
894
901
895
902
self ._send_bind_message(portal_name, stmt_name, bind_data, limit)
896
903
897
- cdef _bind_execute_many(self , str portal_name, str stmt_name,
898
- object bind_data):
899
-
900
- cdef WriteBuffer buf
901
-
904
+ cdef _execute_many_init(self ):
902
905
self ._ensure_connected()
903
906
self ._set_state(PROTOCOL_BIND_EXECUTE_MANY)
904
907
905
908
self .result = None
906
909
self ._discard_data = True
907
- self ._execute_iter = bind_data
908
- self ._execute_portal_name = portal_name
909
- self ._execute_stmt_name = stmt_name
910
910
911
- try :
912
- buf = < WriteBuffer> next(bind_data)
913
- except StopIteration :
914
- self ._push_result()
915
- except Exception as e:
916
- self .result_type = RESULT_FAILED
917
- self .result = e
911
+ cdef _execute_many_writelines(self , str portal_name, str stmt_name,
912
+ object bind_data):
913
+ cdef:
914
+ WriteBuffer packet
915
+ WriteBuffer buf
916
+ list buffers = []
917
+
918
+ if self .result_type == RESULT_FAILED:
919
+ raise StopIteration (True )
920
+
921
+ while len (buffers) < _EXECUTE_MANY_BUF_NUM:
922
+ packet = WriteBuffer.new()
923
+
924
+ while packet.len() < _EXECUTE_MANY_BUF_SIZE:
925
+ try :
926
+ buf = < WriteBuffer> next(bind_data)
927
+ except StopIteration :
928
+ if packet.len() > 0 :
929
+ buffers.append(packet)
930
+ if len (buffers) > 0 :
931
+ self ._writelines(buffers)
932
+ raise StopIteration (True )
933
+ else :
934
+ raise StopIteration (False )
935
+ except Exception as ex:
936
+ raise StopIteration (ex)
937
+ packet.write_buffer(
938
+ self ._build_bind_message(portal_name, stmt_name, buf))
939
+ packet.write_buffer(
940
+ self ._build_execute_message(portal_name, 0 ))
941
+ buffers.append(packet)
942
+ self ._writelines(buffers)
943
+
944
+ cdef _execute_many_done(self , bint data_sent):
945
+ if data_sent:
946
+ self ._write(SYNC_MESSAGE)
947
+ else :
918
948
self ._push_result()
949
+
950
+ cdef _execute_many_fail(self , object error):
951
+ cdef WriteBuffer buf
952
+
953
+ self .result_type = RESULT_FAILED
954
+ self .result = error
955
+
956
+ # We shall rollback in an implicit transaction to prevent partial
957
+ # commit, while do nothing in an explicit transaction and leaving the
958
+ # error to the user
959
+ if self .is_in_transaction():
960
+ self ._execute_many_done(True )
919
961
else :
920
- self ._send_bind_message(portal_name, stmt_name, buf, 0 )
962
+ # Here if the implicit transaction is in `ignore_till_sync` mode,
963
+ # the `ROLLBACK` will be ignored and `Sync` will restore the state;
964
+ # or else the implicit transaction will be rolled back with a
965
+ # warning saying that there was no transaction, but rollback is
966
+ # done anyway, so we could ignore this warning.
967
+ buf = self ._build_parse_message(' ' , ' ROLLBACK' )
968
+ buf.write_buffer(self ._build_bind_message(
969
+ ' ' , ' ' , self ._build_empty_bind_data()))
970
+ buf.write_buffer(self ._build_execute_message(' ' , 0 ))
971
+ buf.write_bytes(SYNC_MESSAGE)
972
+ self ._write(buf)
921
973
922
974
cdef _execute(self , str portal_name, int32_t limit):
923
975
cdef WriteBuffer buf
@@ -927,10 +979,7 @@ cdef class CoreProtocol:
927
979
928
980
self .result = []
929
981
930
- buf = WriteBuffer.new_message(b' E' )
931
- buf.write_str(portal_name, self .encoding) # name of the portal
932
- buf.write_int32(limit) # number of rows to return; 0 - all
933
- buf.end_message()
982
+ buf = self ._build_execute_message(portal_name, limit)
934
983
935
984
buf.write_bytes(SYNC_MESSAGE)
936
985
@@ -1013,6 +1062,9 @@ cdef class CoreProtocol:
1013
1062
cdef _write(self , buf):
1014
1063
raise NotImplementedError
1015
1064
1065
+ cdef _writelines(self , list buffers):
1066
+ raise NotImplementedError
1067
+
1016
1068
cdef _decode_row(self , const char * buf, ssize_t buf_len):
1017
1069
pass
1018
1070
0 commit comments