Skip to content

Commit d3d5421

Browse files
committed
Implement Record.get()
`Record.get()` allows record objects to better masquerade as dicts. Fixes: #330.
1 parent 0ba8a3a commit d3d5421

File tree

3 files changed

+93
-36
lines changed

3 files changed

+93
-36
lines changed

asyncpg/protocol/record/recordobj.c

+76-33
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,55 @@ record_item(ApgRecordObject *o, Py_ssize_t i)
254254
}
255255

256256

257+
/* Lookup a record value by its name. Return 0 on success, -2 if the
258+
* value was not found (with KeyError set), and -1 on all other errors.
259+
*/
260+
static int
261+
record_item_by_name(ApgRecordObject *o, PyObject *item, PyObject **result)
262+
{
263+
PyObject *mapped;
264+
PyObject *val;
265+
Py_ssize_t i;
266+
267+
mapped = PyObject_GetItem(o->desc->mapping, item);
268+
if (mapped == NULL) {
269+
goto noitem;
270+
}
271+
272+
if (!PyIndex_Check(mapped)) {
273+
Py_DECREF(mapped);
274+
goto error;
275+
}
276+
277+
i = PyNumber_AsSsize_t(mapped, PyExc_IndexError);
278+
Py_DECREF(mapped);
279+
280+
if (i < 0) {
281+
if (PyErr_Occurred())
282+
PyErr_Clear();
283+
goto error;
284+
}
285+
286+
val = record_item(o, i);
287+
if (val == NULL) {
288+
PyErr_Clear();
289+
goto error;
290+
}
291+
292+
*result = val;
293+
294+
return 0;
295+
296+
noitem:
297+
_PyErr_SetKeyError(item);
298+
return -2;
299+
300+
error:
301+
PyErr_SetString(PyExc_RuntimeError, "invalid record descriptor");
302+
return -1;
303+
}
304+
305+
257306
static PyObject *
258307
record_subscript(ApgRecordObject* o, PyObject* item)
259308
{
@@ -299,42 +348,13 @@ record_subscript(ApgRecordObject* o, PyObject* item)
299348
}
300349
}
301350
else {
302-
PyObject *mapped;
303-
mapped = PyObject_GetItem(o->desc->mapping, item);
304-
if (mapped != NULL) {
305-
Py_ssize_t i;
306-
PyObject *result;
307-
308-
if (!PyIndex_Check(mapped)) {
309-
Py_DECREF(mapped);
310-
goto noitem;
311-
}
312-
313-
i = PyNumber_AsSsize_t(mapped, PyExc_IndexError);
314-
Py_DECREF(mapped);
315-
316-
if (i < 0) {
317-
if (PyErr_Occurred()) {
318-
PyErr_Clear();
319-
}
320-
goto noitem;
321-
}
351+
PyObject* result;
322352

323-
result = record_item(o, i);
324-
if (result == NULL) {
325-
PyErr_Clear();
326-
goto noitem;
327-
}
353+
if (record_item_by_name(o, item, &result) < 0)
354+
return NULL;
355+
else
328356
return result;
329-
}
330-
else {
331-
goto noitem;
332-
}
333357
}
334-
335-
noitem:
336-
_PyErr_SetKeyError(item);
337-
return NULL;
338358
}
339359

340360

@@ -483,6 +503,28 @@ record_contains(ApgRecordObject *o, PyObject *arg)
483503
}
484504

485505

506+
static PyObject *
507+
record_get(ApgRecordObject* o, PyObject* args)
508+
{
509+
PyObject *key;
510+
PyObject *defval = Py_None;
511+
PyObject *val = NULL;
512+
int res;
513+
514+
if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &defval))
515+
return NULL;
516+
517+
res = record_item_by_name(o, key, &val);
518+
if (res == -2) {
519+
PyErr_Clear();
520+
Py_INCREF(defval);
521+
val = defval;
522+
}
523+
524+
return val;
525+
}
526+
527+
486528
static PySequenceMethods record_as_sequence = {
487529
(lenfunc)record_length, /* sq_length */
488530
0, /* sq_concat */
@@ -506,6 +548,7 @@ static PyMethodDef record_methods[] = {
506548
{"values", (PyCFunction)record_values, METH_NOARGS},
507549
{"keys", (PyCFunction)record_keys, METH_NOARGS},
508550
{"items", (PyCFunction)record_items, METH_NOARGS},
551+
{"get", (PyCFunction)record_get, METH_VARARGS},
509552
{NULL, NULL} /* sentinel */
510553
};
511554

docs/api/index.rst

+7
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ items either by a numeric index or by a field name:
302302

303303
Return an iterator over the *values* of the record *r*.
304304

305+
.. describe:: get(name[, default])
306+
307+
Return the value for *name* if the record has a field named *name*,
308+
else return *default*. If *default* is not given, return ``None``.
309+
310+
.. versionadded:: 0.18
311+
305312
.. method:: values()
306313

307314
Return an iterator over the record values.

tests/test_record.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_record_gc(self):
4949
mapping = {key: val}
5050
with self.checkref(key, val):
5151
r = Record(mapping, (0,))
52-
with self.assertRaises(KeyError):
52+
with self.assertRaises(RuntimeError):
5353
r[key]
5454
del r
5555

@@ -58,7 +58,7 @@ def test_record_gc(self):
5858
mapping = {key: val}
5959
with self.checkref(key, val):
6060
r = Record(mapping, (0,))
61-
with self.assertRaises(KeyError):
61+
with self.assertRaises(RuntimeError):
6262
r[key]
6363
del r
6464

@@ -90,7 +90,7 @@ def test_record_len_getindex(self):
9090
with self.assertRaisesRegex(KeyError, 'spam'):
9191
Record(None, (1,))['spam']
9292

93-
with self.assertRaisesRegex(KeyError, 'spam'):
93+
with self.assertRaisesRegex(RuntimeError, 'invalid record descriptor'):
9494
Record({'spam': 123}, (1,))['spam']
9595

9696
def test_record_slice(self):
@@ -272,6 +272,13 @@ def test_record_cmp(self):
272272
sorted([r1, r2, r3, r4, r5, r6, r7]),
273273
[r1, r2, r3, r6, r7, r4, r5])
274274

275+
def test_record_get(self):
276+
r = Record(R_AB, (42, 43))
277+
with self.checkref(r):
278+
self.assertEqual(r.get('a'), 42)
279+
self.assertEqual(r.get('nonexistent'), None)
280+
self.assertEqual(r.get('nonexistent', 'default'), 'default')
281+
275282
def test_record_not_pickleable(self):
276283
r = Record(R_A, (42,))
277284
with self.assertRaises(Exception):

0 commit comments

Comments
 (0)