esp32/modsocket: Use all supplied arguments to socket.getaddrinfo().

- Completes a longstanding TODO in the code, to not ignore
  the optional family, type, proto and flags arguments to
  socket.getaddrinfo().

- Note that passing family=socket.AF_INET6 will now cause queries
  to fail (OSError -202). Previously this argument was ignored so
  IPV4 results were returned instead.

- Optional 'type' argument is now always copied into the result. If not
  set, results have type SOCK_STREAM.

- Fixes inconsistency where previously querying mDNS local suffix (.local)
  hostnames returned results with socket type 0 (invalid), but all other
  queries returned results with socket type SOCK_STREAM (regardless of
  'type' argument).

- Optional proto argument is now returned in the result tuple, if supplied.

- Optional flags argument is now passed through to lwIP. lwIP has handling
  for AI_NUMERICHOST, AI_V4MAPPED, AI_PASSIVE (untested, constants for
  these are not currently exposed in the esp32 socket module).

- Also fixes a possible memory leak in an obscure code path
  (lwip_getaddrinfo apparently sometimes returns a result structure with
  address "0.0.0.0" instead of failing, and this structure would have been
  leaked.)

This work was funded through GitHub Sponsors.

Signed-off-by: Angus Gratton <angus@redyak.com.au>
This commit is contained in:
Angus Gratton 2024-01-10 14:11:00 +11:00 committed by Damien George
parent 215a982c14
commit a75ca8a1c0

View File

@ -156,64 +156,67 @@ static inline void check_for_exceptions(void) {
mp_handle_pending(true); mp_handle_pending(true);
} }
// This function mimics lwip_getaddrinfo, with added support for mDNS queries #if MICROPY_HW_ENABLE_MDNS_QUERIES
static int _socket_getaddrinfo3(const char *nodename, const char *servname, // This function mimics lwip_getaddrinfo, but makes an mDNS query
STATIC int mdns_getaddrinfo(const char *host_str, const char *port_str,
const struct addrinfo *hints, struct addrinfo **res) { const struct addrinfo *hints, struct addrinfo **res) {
int host_len = strlen(host_str);
#if MICROPY_HW_ENABLE_MDNS_QUERIES
int nodename_len = strlen(nodename);
const int local_len = sizeof(MDNS_LOCAL_SUFFIX) - 1; const int local_len = sizeof(MDNS_LOCAL_SUFFIX) - 1;
if (nodename_len > local_len if (host_len <= local_len ||
&& strcasecmp(nodename + nodename_len - local_len, MDNS_LOCAL_SUFFIX) == 0) { strcasecmp(host_str + host_len - local_len, MDNS_LOCAL_SUFFIX) != 0) {
// mDNS query
char nodename_no_local[nodename_len - local_len + 1];
memcpy(nodename_no_local, nodename, nodename_len - local_len);
nodename_no_local[nodename_len - local_len] = '\0';
esp_ip4_addr_t addr = {0};
esp_err_t err = mdns_query_a(nodename_no_local, MDNS_QUERY_TIMEOUT_MS, &addr);
if (err != ESP_OK) {
if (err == ESP_ERR_NOT_FOUND) {
*res = NULL;
return 0;
}
*res = NULL;
return err;
}
struct addrinfo *ai = memp_malloc(MEMP_NETDB);
if (ai == NULL) {
*res = NULL;
return EAI_MEMORY;
}
memset(ai, 0, sizeof(struct addrinfo) + sizeof(struct sockaddr_storage));
struct sockaddr_in *sa = (struct sockaddr_in *)((uint8_t *)ai + sizeof(struct addrinfo));
inet_addr_from_ip4addr(&sa->sin_addr, &addr);
sa->sin_family = AF_INET;
sa->sin_len = sizeof(struct sockaddr_in);
sa->sin_port = lwip_htons((u16_t)atoi(servname));
ai->ai_family = AF_INET;
ai->ai_canonname = ((char *)sa + sizeof(struct sockaddr_storage));
memcpy(ai->ai_canonname, nodename, nodename_len + 1);
ai->ai_addrlen = sizeof(struct sockaddr_storage);
ai->ai_addr = (struct sockaddr *)sa;
*res = ai;
return 0; return 0;
} }
#endif
// Normal query // mDNS query
return lwip_getaddrinfo(nodename, servname, hints, res); char host_no_local[host_len - local_len + 1];
memcpy(host_no_local, host_str, host_len - local_len);
host_no_local[host_len - local_len] = '\0';
esp_ip4_addr_t addr = {0};
esp_err_t err = mdns_query_a(host_no_local, MDNS_QUERY_TIMEOUT_MS, &addr);
if (err != ESP_OK) {
if (err == ESP_ERR_NOT_FOUND) {
*res = NULL;
return 0;
}
*res = NULL;
return err;
}
struct addrinfo *ai = memp_malloc(MEMP_NETDB);
if (ai == NULL) {
*res = NULL;
return EAI_MEMORY;
}
memset(ai, 0, sizeof(struct addrinfo) + sizeof(struct sockaddr_storage));
struct sockaddr_in *sa = (struct sockaddr_in *)((uint8_t *)ai + sizeof(struct addrinfo));
inet_addr_from_ip4addr(&sa->sin_addr, &addr);
sa->sin_family = AF_INET;
sa->sin_len = sizeof(struct sockaddr_in);
sa->sin_port = lwip_htons((u16_t)atoi(port_str));
ai->ai_family = AF_INET;
ai->ai_canonname = ((char *)sa + sizeof(struct sockaddr_storage));
memcpy(ai->ai_canonname, host_str, host_len + 1);
ai->ai_addrlen = sizeof(struct sockaddr_storage);
ai->ai_addr = (struct sockaddr *)sa;
ai->ai_socktype = SOCK_STREAM;
if (hints) {
ai->ai_socktype = hints->ai_socktype;
ai->ai_protocol = hints->ai_protocol;
}
*res = ai;
return 0;
} }
#endif // MICROPY_HW_ENABLE_MDNS_QUERIES
static int _socket_getaddrinfo2(const mp_obj_t host, const mp_obj_t portx, struct addrinfo **resp) { static void _getaddrinfo_inner(const mp_obj_t host, const mp_obj_t portx,
const struct addrinfo hints = { const struct addrinfo *hints, struct addrinfo **res) {
.ai_family = AF_INET, int retval = 0;
.ai_socktype = SOCK_STREAM,
}; *res = NULL;
mp_obj_t port = portx; mp_obj_t port = portx;
if (mp_obj_is_integer(port)) { if (mp_obj_is_integer(port)) {
@ -231,27 +234,37 @@ static int _socket_getaddrinfo2(const mp_obj_t host, const mp_obj_t portx, struc
} }
MP_THREAD_GIL_EXIT(); MP_THREAD_GIL_EXIT();
int res = _socket_getaddrinfo3(host_str, port_str, &hints, resp);
#if MICROPY_HW_ENABLE_MDNS_QUERIES
retval = mdns_getaddrinfo(host_str, port_str, hints, res);
#endif
if (retval == 0 && *res == NULL) {
// Normal query
retval = lwip_getaddrinfo(host_str, port_str, hints, res);
}
MP_THREAD_GIL_ENTER(); MP_THREAD_GIL_ENTER();
// Per docs: instead of raising gaierror getaddrinfo raises negative error number // Per docs: instead of raising gaierror getaddrinfo raises negative error number
if (res != 0) { if (retval != 0) {
mp_raise_OSError(res > 0 ? -res : res); mp_raise_OSError(retval > 0 ? -retval : retval);
} }
// Somehow LwIP returns a resolution of 0.0.0.0 for failed lookups, traced it as far back // Somehow LwIP returns a resolution of 0.0.0.0 for failed lookups, traced it as far back
// as netconn_gethostbyname_addrtype returning OK instead of error. // as netconn_gethostbyname_addrtype returning OK instead of error.
if (*resp == NULL || if (*res == NULL ||
(strcmp(resp[0]->ai_canonname, "0.0.0.0") == 0 && strcmp(host_str, "0.0.0.0") != 0)) { (strcmp(res[0]->ai_canonname, "0.0.0.0") == 0 && strcmp(host_str, "0.0.0.0") != 0)) {
lwip_freeaddrinfo(*res);
mp_raise_OSError(-2); // name or service not known mp_raise_OSError(-2); // name or service not known
} }
return res; assert(retval == 0 && *res != NULL);
} }
STATIC void _socket_getaddrinfo(const mp_obj_t addrtuple, struct addrinfo **resp) { STATIC void _socket_getaddrinfo(const mp_obj_t addrtuple, struct addrinfo **resp) {
mp_obj_t *elem; mp_obj_t *elem;
mp_obj_get_array_fixed_n(addrtuple, 2, &elem); mp_obj_get_array_fixed_n(addrtuple, 2, &elem);
_socket_getaddrinfo2(elem[0], elem[1], resp); _getaddrinfo_inner(elem[0], elem[1], NULL, resp);
} }
STATIC mp_obj_t socket_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { STATIC mp_obj_t socket_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
@ -897,10 +910,32 @@ STATIC MP_DEFINE_CONST_OBJ_TYPE(
); );
STATIC mp_obj_t esp_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) { STATIC mp_obj_t esp_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) {
// TODO support additional args beyond the first two struct addrinfo hints = { };
struct addrinfo *res = NULL; struct addrinfo *res = NULL;
_socket_getaddrinfo2(args[0], args[1], &res);
// Optional args: family=0, type=0, proto=0, flags=0, where 0 is "least narrow"
if (n_args > 2) {
hints.ai_family = mp_obj_get_int(args[2]);
}
if (n_args > 3) {
hints.ai_socktype = mp_obj_get_int(args[3]);
}
if (hints.ai_socktype == 0) {
// This is slightly different to CPython with POSIX getaddrinfo. In
// CPython, calling socket.getaddrinfo() with socktype=0 returns any/all
// supported SocketKind values. Here, lwip_getaddrinfo() will echo
// whatever socktype was supplied to the caller. Rather than returning 0
// (invalid in a result), make it SOCK_STREAM.
hints.ai_socktype = SOCK_STREAM;
}
if (n_args > 4) {
hints.ai_protocol = mp_obj_get_int(args[4]);
}
if (n_args > 5) {
hints.ai_flags = mp_obj_get_int(args[5]);
}
_getaddrinfo_inner(args[0], args[1], &hints, &res);
mp_obj_t ret_list = mp_obj_new_list(0, NULL); mp_obj_t ret_list = mp_obj_new_list(0, NULL);
for (struct addrinfo *resi = res; resi; resi = resi->ai_next) { for (struct addrinfo *resi = res; resi; resi = resi->ai_next) {
@ -927,9 +962,7 @@ STATIC mp_obj_t esp_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) {
mp_obj_list_append(ret_list, mp_obj_new_tuple(5, addrinfo_objs)); mp_obj_list_append(ret_list, mp_obj_new_tuple(5, addrinfo_objs));
} }
if (res) { lwip_freeaddrinfo(res);
lwip_freeaddrinfo(res);
}
return ret_list; return ret_list;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(esp_socket_getaddrinfo_obj, 2, 6, esp_socket_getaddrinfo); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(esp_socket_getaddrinfo_obj, 2, 6, esp_socket_getaddrinfo);