diff --git a/src/commonlib/include/commonlib/list.h b/src/commonlib/include/commonlib/list.h index ce74ca34b6..a6c87ceee2 100644 --- a/src/commonlib/include/commonlib/list.h +++ b/src/commonlib/include/commonlib/list.h @@ -9,55 +9,86 @@ #include struct list_node { - struct list_node *next; - struct list_node *prev; + struct _internal_do_not_access_list_node { + struct list_node *next; + struct list_node *prev; + } _internal_do_not_access; }; +// These macros do NOT belong to the public API. +#define NEXT(ptr) ((ptr)->_internal_do_not_access.next) +#define PREV(ptr) ((ptr)->_internal_do_not_access.prev) + +/* Initialize a circular list, with `head` being a placeholder head node. */ +void _list_init(struct list_node *head); + // Remove list_node node from the doubly linked list it's a part of. void list_remove(struct list_node *node); // Insert list_node node after list_node after in a doubly linked list. void list_insert_after(struct list_node *node, struct list_node *after); // Insert list_node node before list_node before in a doubly linked list. +// `before` must not be the placeholder head node. void list_insert_before(struct list_node *node, struct list_node *before); // Append the node to the end of the list. -void list_append(struct list_node *node, struct list_node *head); +static inline void list_append(struct list_node *node, struct list_node *head) +{ + _list_init(head); + /* With a circular list, we just need to insert before the head. */ + list_insert_before(node, head); +} // Return if the list is empty. static inline bool list_is_empty(const struct list_node *head) { - return !head->next; + return !NEXT(head) || NEXT(head) == head; } // Get next node. static inline struct list_node *list_next(const struct list_node *node, const struct list_node *head) { - return node->next; + return NEXT(node) == head ? NULL : NEXT(node); }; // Get prev node. static inline struct list_node *list_prev(const struct list_node *node, const struct list_node *head) { - return node->prev == head ? NULL : node->prev; + return PREV(node) == head ? NULL : PREV(node); }; // Get first node. static inline struct list_node *list_first(const struct list_node *head) { - return list_is_empty(head) ? NULL : head->next; + return list_next(head, head); } // Get last node. -struct list_node *list_last(const struct list_node *head); +static inline struct list_node *list_last(const struct list_node *head) +{ + return list_prev(head, head); +} // Get the number of list elements. size_t list_length(const struct list_node *head); -#define list_for_each(ptr, head, member) \ - for ((ptr) = container_of((head).next, typeof(*(ptr)), member); \ - (uintptr_t)ptr + (uintptr_t)offsetof(typeof(*(ptr)), member); \ - (ptr) = container_of((ptr)->member.next, \ +/* + * Explanation of `ptr` initialization: + * 1. head.next != NULL: This means the list isn't empty. As the implementation ensures that + * _list_init() is called when the very first element is added, we can safely assume that + * the list is circular, and hence set `ptr` to the 1st element. + * 2. head.next == NULL: This means the list is empty, and _list_init() hasn't been called. + * As the `head` arg might be const, we cannot simply call _list_init() here. Instead, we set + * `ptr` to a special value such that `&(ptr->member) == &head`, causing the loop to + * terminate immediately. + */ +#define list_for_each(ptr, head, member) \ + for ((ptr) = container_of((head)._internal_do_not_access.next ?: &(head), typeof(*(ptr)), member); \ + &((ptr)->member) != &(head); \ + (ptr) = container_of((ptr)->member._internal_do_not_access.next, \ typeof(*(ptr)), member)) +#undef NEXT +#undef PREV + #endif /* __COMMONLIB_LIST_H__ */ diff --git a/src/commonlib/list.c b/src/commonlib/list.c index 1af92c8ac5..93a9b12887 100644 --- a/src/commonlib/list.c +++ b/src/commonlib/list.c @@ -1,51 +1,47 @@ /* Taken from depthcharge: src/base/list.c */ /* SPDX-License-Identifier: GPL-2.0-or-later */ +#include #include +#define NEXT(ptr) ((ptr)->_internal_do_not_access.next) +#define PREV(ptr) ((ptr)->_internal_do_not_access.prev) + +void _list_init(struct list_node *head) +{ + if (!NEXT(head)) { + assert(!PREV(head)); + PREV(head) = NEXT(head) = head; + } +} + void list_remove(struct list_node *node) { - if (node->prev) - node->prev->next = node->next; - if (node->next) - node->next->prev = node->prev; + /* Cannot remove the head node. */ + assert(PREV(node) && NEXT(node)); + NEXT(PREV(node)) = NEXT(node); + PREV(NEXT(node)) = PREV(node); } void list_insert_after(struct list_node *node, struct list_node *after) { - node->next = after->next; - node->prev = after; - after->next = node; - if (node->next) - node->next->prev = node; + /* Check uninitialized head node. */ + if (!PREV(after)) + _list_init(after); + NEXT(node) = NEXT(after); + PREV(node) = after; + NEXT(after) = node; + PREV(NEXT(node)) = node; } void list_insert_before(struct list_node *node, struct list_node *before) { - node->prev = before->prev; - node->next = before; - before->prev = node; - if (node->prev) - node->prev->next = node; -} - -void list_append(struct list_node *node, struct list_node *head) -{ - while (head->next) - head = head->next; - - list_insert_after(node, head); -} - -struct list_node *list_last(const struct list_node *head) -{ - if (!head->next) - return NULL; - - struct list_node *ptr = head->next; - while (ptr->next) - ptr = ptr->next; - return ptr; + /* `before` cannot be an uninitialized head node. */ + assert(PREV(before)); + PREV(node) = PREV(before); + NEXT(node) = before; + PREV(before) = node; + NEXT(PREV(node)) = node; } size_t list_length(const struct list_node *head) @@ -60,3 +56,6 @@ size_t list_length(const struct list_node *head) return len; } + +#undef NEXT +#undef PREV diff --git a/tests/commonlib/list-test.c b/tests/commonlib/list-test.c index 568fa1d3dd..e4dbd8e14a 100644 --- a/tests/commonlib/list-test.c +++ b/tests/commonlib/list-test.c @@ -45,7 +45,7 @@ static void test_list_one_node(void **state) static void test_list_insert_after(void **state) { int i = 0; - struct list_node head = { .prev = NULL, .next = NULL }; + struct list_node head = {}; struct test_container *c1 = (struct test_container *)malloc(sizeof(*c1)); struct test_container *c2 = (struct test_container *)malloc(sizeof(*c2)); struct test_container *c3 = (struct test_container *)malloc(sizeof(*c2)); @@ -86,7 +86,7 @@ static void test_list_insert_after(void **state) static void test_list_insert_before(void **state) { int i = 0; - struct list_node head = { .prev = NULL, .next = NULL }; + struct list_node head = {}; struct test_container *c1 = (struct test_container *)malloc(sizeof(*c1)); struct test_container *c2 = (struct test_container *)malloc(sizeof(*c2)); struct test_container *c3 = (struct test_container *)malloc(sizeof(*c2)); @@ -105,7 +105,6 @@ static void test_list_insert_before(void **state) list_insert_before(&c2->list_node, &c3->list_node); list_insert_before(&c1->list_node, &c2->list_node); - list_for_each(ptr, head, list_node) { assert_int_equal(values[i], ptr->value); i++; @@ -120,9 +119,17 @@ static void test_list_insert_before(void **state) free(c1); } +static void test_list_insert_before_head(void **state) +{ + struct list_node head = {}; + struct test_container c = {}; + + expect_assert_failure(list_insert_before(&c.list_node, &head)); +} + static void test_list_remove(void **state) { - struct list_node head = { .prev = NULL, .next = NULL }; + struct list_node head = {}; struct test_container *c1 = (struct test_container *)malloc(sizeof(*c1)); struct test_container *c2 = (struct test_container *)malloc(sizeof(*c2)); @@ -141,6 +148,12 @@ static void test_list_remove(void **state) free(c1); } +static void test_list_remove_head(void **state) +{ + struct list_node head = {}; + expect_assert_failure(list_remove(&head)); +} + static void test_list_append(void **state) { size_t idx; @@ -170,7 +183,9 @@ int main(void) cmocka_unit_test(test_list_one_node), cmocka_unit_test(test_list_insert_after), cmocka_unit_test(test_list_insert_before), + cmocka_unit_test(test_list_insert_before_head), cmocka_unit_test(test_list_remove), + cmocka_unit_test(test_list_remove_head), cmocka_unit_test(test_list_append), };