Skip to content

Commit 6de82d2

Browse files
committed
PYTHON-1874 Fix coll.aggregate() when result is missing the "ns" field
1 parent 0b72f88 commit 6de82d2

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

pymongo/aggregation.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def _cursor_namespace(self):
7777
"""The namespace in which the aggregate command is run."""
7878
raise NotImplementedError
7979

80+
@property
81+
def _cursor_collection(self, cursor_doc):
82+
"""The Collection used for the aggregate command cursor."""
83+
raise NotImplementedError
84+
8085
@property
8186
def _database(self):
8287
"""The database against which the aggregation command is run."""
@@ -152,18 +157,9 @@ def get_cursor(self, session, server, sock_info, slave_ok):
152157
"ns": self._cursor_namespace,
153158
}
154159

155-
# Get collection to target with cursor.
156-
ns = cursor["ns"]
157-
_, collname = ns.split(".", 1)
158-
aggregation_collection = self._database.get_collection(
159-
collname, codec_options=self._target.codec_options,
160-
read_preference=read_preference,
161-
write_concern=self._target.write_concern,
162-
read_concern=self._target.read_concern)
163-
164160
# Create and return cursor instance.
165161
return self._cursor_class(
166-
aggregation_collection, cursor, sock_info.address,
162+
self._cursor_collection(cursor), cursor, sock_info.address,
167163
batch_size=self._batch_size or 0,
168164
max_await_time_ms=self._max_await_time_ms,
169165
session=session, explicit_session=self._explicit_session)
@@ -188,6 +184,10 @@ def _aggregation_target(self):
188184
def _cursor_namespace(self):
189185
return self._target.full_name
190186

187+
def _cursor_collection(self, cursor):
188+
"""The Collection used for the aggregate command cursor."""
189+
return self._target
190+
191191
@property
192192
def _database(self):
193193
return self._target.database
@@ -209,16 +209,24 @@ def _aggregation_target(self):
209209

210210
@property
211211
def _cursor_namespace(self):
212-
return "%s.%s.aggregate" % (self._target.name, "$cmd")
212+
return "%s.$cmd.aggregate" % (self._target.name,)
213213

214214
@property
215215
def _database(self):
216216
return self._target
217217

218+
def _cursor_collection(self, cursor):
219+
"""The Collection used for the aggregate command cursor."""
220+
# Collection level aggregate may not always return the "ns" field
221+
# according to our MockupDB tests. Let's handle that case for db level
222+
# aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
223+
_, collname = cursor.get("ns", self._cursor_namespace).split(".", 1)
224+
return self._database[collname]
225+
218226
@staticmethod
219227
def _check_compat(sock_info):
220228
# Older server version don't raise a descriptive error, so we raise
221229
# one instead.
222230
if not sock_info.max_wire_version >= 6:
223-
err_msg = "Database.aggregation is only supported on MongoDB 3.6+."
231+
err_msg = "Database.aggregate() is only supported on MongoDB 3.6+."
224232
raise ConfigurationError(err_msg)

test/test_database.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,7 @@ def test_database_aggregation_fake_cursor(self):
10541054

10551055
@client_context.require_version_max(3, 6, 0, -1)
10561056
def test_database_aggregation_unsupported(self):
1057-
err_msg = "Database.aggregation is only supported on MongoDB 3.6\+."
1057+
err_msg = "Database.aggregate\(\) is only supported on MongoDB 3.6\+."
10581058
with self.assertRaisesRegex(ConfigurationError, err_msg):
10591059
with self.admin.aggregate(self.pipeline) as _:
10601060
pass

0 commit comments

Comments
 (0)