Skip to content

Implement Record.get() #331

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 83 additions & 33 deletions asyncpg/protocol/record/recordobj.c
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,62 @@ record_item(ApgRecordObject *o, Py_ssize_t i)
}


typedef enum item_by_name_result {
APG_ITEM_FOUND = 0,
APG_ERROR = -1,
APG_ITEM_NOT_FOUND = -2
} item_by_name_result_t;


/* Lookup a record value by its name. Return 0 on success, -2 if the
* value was not found (with KeyError set), and -1 on all other errors.
*/
static item_by_name_result_t
record_item_by_name(ApgRecordObject *o, PyObject *item, PyObject **result)
{
PyObject *mapped;
PyObject *val;
Py_ssize_t i;

mapped = PyObject_GetItem(o->desc->mapping, item);
if (mapped == NULL) {
goto noitem;
}

if (!PyIndex_Check(mapped)) {
Py_DECREF(mapped);
goto error;
}

i = PyNumber_AsSsize_t(mapped, PyExc_IndexError);
Py_DECREF(mapped);

if (i < 0) {
if (PyErr_Occurred())
PyErr_Clear();
goto error;
}

val = record_item(o, i);
if (val == NULL) {
PyErr_Clear();
goto error;
}

*result = val;

return APG_ITEM_FOUND;

noitem:
PyErr_SetObject(PyExc_KeyError, item);
return APG_ITEM_NOT_FOUND;

error:
PyErr_SetString(PyExc_RuntimeError, "invalid record descriptor");
return APG_ERROR;
}


static PyObject *
record_subscript(ApgRecordObject* o, PyObject* item)
{
Expand Down Expand Up @@ -299,42 +355,13 @@ record_subscript(ApgRecordObject* o, PyObject* item)
}
}
else {
PyObject *mapped;
mapped = PyObject_GetItem(o->desc->mapping, item);
if (mapped != NULL) {
Py_ssize_t i;
PyObject *result;

if (!PyIndex_Check(mapped)) {
Py_DECREF(mapped);
goto noitem;
}

i = PyNumber_AsSsize_t(mapped, PyExc_IndexError);
Py_DECREF(mapped);

if (i < 0) {
if (PyErr_Occurred()) {
PyErr_Clear();
}
goto noitem;
}
PyObject* result;

result = record_item(o, i);
if (result == NULL) {
PyErr_Clear();
goto noitem;
}
if (record_item_by_name(o, item, &result) < 0)
return NULL;
else
return result;
}
else {
goto noitem;
}
}

noitem:
_PyErr_SetKeyError(item);
return NULL;
}


Expand Down Expand Up @@ -483,6 +510,28 @@ record_contains(ApgRecordObject *o, PyObject *arg)
}


static PyObject *
record_get(ApgRecordObject* o, PyObject* args)
{
PyObject *key;
PyObject *defval = Py_None;
PyObject *val = NULL;
int res;

if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &defval))
return NULL;

res = record_item_by_name(o, key, &val);
if (res == APG_ITEM_NOT_FOUND) {
PyErr_Clear();
Py_INCREF(defval);
val = defval;
}

return val;
}


static PySequenceMethods record_as_sequence = {
(lenfunc)record_length, /* sq_length */
0, /* sq_concat */
Expand All @@ -506,6 +555,7 @@ static PyMethodDef record_methods[] = {
{"values", (PyCFunction)record_values, METH_NOARGS},
{"keys", (PyCFunction)record_keys, METH_NOARGS},
{"items", (PyCFunction)record_items, METH_NOARGS},
{"get", (PyCFunction)record_get, METH_VARARGS},
{NULL, NULL} /* sentinel */
};

Expand Down
7 changes: 7 additions & 0 deletions docs/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,13 @@ items either by a numeric index or by a field name:

Return an iterator over the *values* of the record *r*.

.. describe:: get(name[, default])

Return the value for *name* if the record has a field named *name*,
else return *default*. If *default* is not given, return ``None``.

.. versionadded:: 0.18

.. method:: values()

Return an iterator over the record values.
Expand Down
13 changes: 10 additions & 3 deletions tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_record_gc(self):
mapping = {key: val}
with self.checkref(key, val):
r = Record(mapping, (0,))
with self.assertRaises(KeyError):
with self.assertRaises(RuntimeError):
r[key]
del r

Expand All @@ -58,7 +58,7 @@ def test_record_gc(self):
mapping = {key: val}
with self.checkref(key, val):
r = Record(mapping, (0,))
with self.assertRaises(KeyError):
with self.assertRaises(RuntimeError):
r[key]
del r

Expand Down Expand Up @@ -90,7 +90,7 @@ def test_record_len_getindex(self):
with self.assertRaisesRegex(KeyError, 'spam'):
Record(None, (1,))['spam']

with self.assertRaisesRegex(KeyError, 'spam'):
with self.assertRaisesRegex(RuntimeError, 'invalid record descriptor'):
Record({'spam': 123}, (1,))['spam']

def test_record_slice(self):
Expand Down Expand Up @@ -272,6 +272,13 @@ def test_record_cmp(self):
sorted([r1, r2, r3, r4, r5, r6, r7]),
[r1, r2, r3, r6, r7, r4, r5])

def test_record_get(self):
r = Record(R_AB, (42, 43))
with self.checkref(r):
self.assertEqual(r.get('a'), 42)
self.assertEqual(r.get('nonexistent'), None)
self.assertEqual(r.get('nonexistent', 'default'), 'default')

def test_record_not_pickleable(self):
r = Record(R_A, (42,))
with self.assertRaises(Exception):
Expand Down