Skip to content

Commit b190169

Browse files
committed
support setting a detatched HEAD
1 parent a063867 commit b190169

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

src/repository.c

+21-11
Original file line numberDiff line numberDiff line change
@@ -179,21 +179,31 @@ Repository_head__get__(Repository *self)
179179
}
180180

181181
int
182-
Repository_head__set__(Repository *self, PyObject *py_refname)
182+
Repository_head__set__(Repository *self, PyObject *py_val)
183183
{
184184
int err;
185-
const char *refname;
186-
PyObject *trefname;
185+
if (PyObject_TypeCheck(py_val, &OidType)) {
186+
git_oid oid;
187+
py_oid_to_git_oid(py_val, &oid);
188+
err = git_repository_set_head_detached(self->repo, &oid, NULL, NULL);
189+
if (err < 0) {
190+
Error_set(err);
191+
return -1;
192+
}
193+
} else {
194+
const char *refname;
195+
PyObject *trefname;
187196

188-
refname = py_str_borrow_c_str(&trefname, py_refname, NULL);
189-
if (refname == NULL)
190-
return -1;
197+
refname = py_str_borrow_c_str(&trefname, py_val, NULL);
198+
if (refname == NULL)
199+
return -1;
191200

192-
err = git_repository_set_head(self->repo, refname, NULL, NULL);
193-
Py_DECREF(trefname);
194-
if (err < 0) {
195-
Error_set_str(err, refname);
196-
return -1;
201+
err = git_repository_set_head(self->repo, refname, NULL, NULL);
202+
Py_DECREF(trefname);
203+
if (err < 0) {
204+
Error_set_str(err, refname);
205+
return -1;
206+
}
197207
}
198208

199209
return 0;

test/test_repository.py

+9
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ def test_head(self):
7070
self.assertFalse(self.repo.head_is_unborn)
7171
self.assertFalse(self.repo.head_is_detached)
7272

73+
def test_set_head(self):
74+
# Test setting a detatched HEAD.
75+
self.repo.head = Oid(hex=PARENT_SHA)
76+
self.assertEqual(self.repo.head.target.hex, PARENT_SHA)
77+
# And test setting a normal HEAD.
78+
self.repo.head = "refs/heads/master"
79+
self.assertEqual(self.repo.head.name, "refs/heads/master")
80+
self.assertEqual(self.repo.head.target.hex, HEAD_SHA)
81+
7382
def test_read(self):
7483
self.assertRaises(TypeError, self.repo.read, 123)
7584
self.assertRaisesWithArg(KeyError, '1' * 40, self.repo.read, '1' * 40)

0 commit comments

Comments
 (0)