#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <stddef.h>
#include <string.h>
#include <sys/un.h>
#include <unistd.h>

#include "include/pw_args.h"
#include "include/pwlib/netutils.h"
#include "include/pwlib/socket.h"
#include "include/pw_utf.h"
#include "src/pw_interfaces_internal.h"
#include "src/types/struct_internal.h"
 13
 14// status codes
 15uint16_t PweBadAddressFamily = 0;
 16uint16_t PweBadIpAddress = 0;
 17uint16_t PweBadPort = 0;
 18uint16_t PweHostAddressExpected = 0;
 19uint16_t PweAddressFamilyMismatch = 0;
 20uint16_t PweSocketNameTooLong = 0;
 21uint16_t PweMissingNetmask = 0;
 22uint16_t PweBadNetmask = 0;
 23uint16_t PwePortUnspecified = 0;
 24
 25// socket data
 26
// Instance data of the Socket type.
PW_STRUCT(_PwSocket) {
    int sock;  // OS descriptor; -1 when no descriptor is attached (e.g. created without ctor args for accept)
    int listen_backlog;  // 0: not listening
    uint16_t new_sock_type;  // for accept, to create socket of desired type (set by listen)
    bool own_fd;  // if true, close_socket() closes `sock`; set by socket_create or from set_fd's `move` argument

    // args the socket() was called with
    // (accepted sockets copy them from the listener, mainly for informational purposes)
    int domain;
    int type;
    int proto;

    _PwValue local_addr;   // SockAddr used for bind; accepted sockets get a clone of the listener's
    _PwValue remote_addr;  // peer SockAddr: the one passed to connect, or the one filled in by accept
};
 41
 42
 43static void close_socket(_PwSocket* s)
 44{
 45    if (s->sock != -1 && s->own_fd) {
 46        close(s->sock);
 47        s->sock = -1;
 48    }
 49    pw_destroy(&s->local_addr);
 50    pw_destroy(&s->remote_addr);
 51}
 52
 53/****************************************************************
 54 * SockAddr basic interface
 55 */
 56
// Type id of SockAddr; 0 until _pw_init_socket() registers the type.
uint16_t PwTypeId_SockAddr = 0;
 58
 59static bool sockaddr_hash(PwMethod_Basic_hash* mthis, PwValuePtr self, PwHashContext* ctx, _PwCompoundChain* tail)
 60{
 61    _PwSockAddrData* sa = pw_this_data(self);
 62
 63    _pw_hash_uint64(ctx, self->type_id);
 64    _pw_hash_uint64(ctx, sa->addr.ss_family);
 65    switch (sa->addr.ss_family) {
 66        case AF_LOCAL: {
 67            char* path = ((struct sockaddr_un*) &sa->addr)->sun_path;
 68            for (;;) {
 69                unsigned char chr = *path++;
 70                if (chr == 0) {
 71                    break;
 72                }
 73                _pw_hash_uint64(ctx, chr);
 74            }
 75            break;
 76        }
 77        case AF_INET:
 78            _pw_hash_uint64(ctx, ((struct sockaddr_in*) &sa->addr)->sin_addr.s_addr);
 79            _pw_hash_uint64(ctx, ((struct sockaddr_in*) &sa->addr)->sin_port);
 80            break;
 81        case AF_INET6: {
 82            uint64_t* addr = (uint64_t*) &((struct sockaddr_in6*) &sa->addr)->sin6_addr;
 83            _pw_hash_uint64(ctx, *addr++);
 84            _pw_hash_uint64(ctx, *addr);
 85            _pw_hash_uint64(ctx, ((struct sockaddr_in6*) &sa->addr)->sin6_port);
 86            break;
 87        }
 88        default:
 89            break;
 90    }
 91    _pw_hash_uint64(ctx, sa->netmask);
 92    return true;
 93}
 94
 95static bool dump_sockaddr(FILE* fp, PwValuePtr sockaddr)
 96{
 97    _PwSockAddrData* sa = _pw_get_struct_ptr(sockaddr, PwTypeId_SockAddr);
 98
 99    const char* addr;
100    char addr_buf[1024];
101    in_port_t port = 0;
102    switch (sa->addr.ss_family) {
103        case AF_UNSPEC:
104            addr = "[unspecified]";
105            break;
106        case AF_LOCAL:
107            addr = ((struct sockaddr_un*) &sa->addr)->sun_path;
108            break;
109        case AF_INET:
110            port = ntohs(((struct sockaddr_in*) &sa->addr)->sin_port);
111            addr = inet_ntop(AF_INET, &((struct sockaddr_in*) &sa->addr)->sin_addr, addr_buf, sizeof(addr_buf));
112            break;
113        case AF_INET6:
114            port = ntohs(((struct sockaddr_in6*) &sa->addr)->sin6_port);
115            addr = inet_ntop(AF_INET6, &((struct sockaddr_in6*) &sa->addr)->sin6_addr, addr_buf, sizeof(addr_buf));
116            break;
117        default:
118            addr = "<not supported>";
119            break;
120    }
121    if (!addr) {
122        addr = "<error>";
123    }
124    if (port) {
125        if (sa->addr.ss_family == AF_INET6) {
126            fprintf(fp, " [%s]:%u", addr, port);
127        } else {
128            fprintf(fp, " %s:%u", addr, port);
129        }
130    } else {
131        fputc(' ', fp);
132        fputs(addr, fp);
133    }
134    if (sa->netmask) {
135        fprintf(fp, "/%u\n", sa->netmask);
136    } else {
137        fputc('\n', fp);
138    }
139    return true;
140}
141
142static bool sockaddr_dump(PwMethod_Basic_dump* mthis, PwValuePtr self, FILE* fp, int indent, _PwCompoundChain* tail)
143{
144    if (!pw_super(mthis, self, fp, indent, tail)) {
145        return false;
146    }
147    _pw_print_indent(fp, indent);
148    dump_sockaddr(fp, self);
149    return true;
150}
151
152static PwInterface_Basic sockaddr_basic_interface = {
153    .hash = { .func = sockaddr_hash },
154    .dump = { .func = sockaddr_dump }
155};
156
157/****************************************************************
 * Socket type and basic interface
159 */
160
// Type id of Socket; 0 until _pw_init_socket() registers the type.
uint16_t PwTypeId_Socket = 0;
162
163static bool socket_create(PwMethod_Basic_create* mthis, PwValuePtr result, PwCtorArgs* ctor_args)
164{
165    if (!pw_super(mthis, result, ctor_args)) {
166        return false;
167    }
168
169    _PwSocket* s = pw_this_data(result);
170
171    PwSocketCtorArgs* args = pw_this_ctor_args();
172
173    if (!args) {
174        // special case for internal use in accept method
175        s->sock = -1;
176        return true;
177    }
178
179    s->sock = socket(args->domain, args->type, args->protocol);
180    if (s->sock != -1) {
181
182        s->own_fd = true;
183        s->domain = args->domain;
184        s->type   = args->type;
185        s->proto  = args->protocol;
186
187        return true;
188    }
189    int _errno = errno;
190    if (!pw_super_call(destroy, mthis, result, nullptr)) { /* no op */ }
191    pw_set_status(PwErrno(_errno));
192    return false;
193}
194
195static bool socket_destroy(PwMethod_Basic_destroy* mthis, PwValuePtr self, _PwCompoundChain* tail)
196{
197    PwValuePtr value_seen = _pw_on_chain(self, tail);
198    if (value_seen) {
199        return true;
200    }
201    close_socket(pw_this_data(self));
202    return pw_super(mthis, self, tail);
203}
204
205static bool socket_dump(PwMethod_Basic_dump* mthis, PwValuePtr self, FILE* fp, int indent, _PwCompoundChain* tail)
206{
207    if (!pw_super(mthis, self, fp, indent, tail)) {
208        return false;
209    }
210
211    _PwSocket* s = pw_this_data(self);
212
213    // domain to string
214
215    char* domain;
216    char domain_other[32];
217    switch (s->domain) {
218        case AF_UNSPEC:  domain = "AF_UNSPEC"; break;
219        case AF_LOCAL:   domain = "AF_LOCAL"; break;
220        case AF_INET:    domain = "AF_INET"; break;
221        case AF_INET6:   domain = "AF_INET6"; break;
222        case AF_NETLINK: domain = "AF_NETLINK"; break;
223        case AF_PACKET:  domain = "AF_PACKET"; break;
224        default:
225            snprintf(domain_other, sizeof(domain_other), "AF_(%d)", s->domain);
226            domain = domain_other;
227            break;
228    }
229
230    // socket type to string
231
232    char* type;
233    char type_other[32];
234    char type_buf[80];
235    int t = s->type & ~(SOCK_CLOEXEC | SOCK_NONBLOCK);
236    switch (t) {
237        case SOCK_STREAM: type = "SOCK_STREAM"; break;
238        case SOCK_DGRAM:  type = "SOCK_DGRAM"; break;
239        case SOCK_RAW:    type = "SOCK_RAW"; break;
240        default:
241            snprintf(type_other, sizeof(type_other), "SOCK_(%d)", t);
242            type = type_other;
243            break;
244    }
245    if (s->type & (SOCK_CLOEXEC | SOCK_NONBLOCK)) {
246        strcpy(type_buf, type);
247        strcat(type_buf, "{");
248        if (s->type & SOCK_CLOEXEC) {
249            strcat(type_buf, "SOCK_CLOEXEC");
250        }
251        if (s->type & SOCK_NONBLOCK) {
252            if (s->type & SOCK_CLOEXEC) {
253                strcat(type_buf, ",");
254            }
255            strcat(type_buf, "SOCK_NONBLOCK");
256        }
257        strcat(type_buf, "}");
258        type = type_buf;
259    }
260
261    // proto to string
262
263    char* protocol;
264    struct protoent proto;
265    struct protoent *proto_result;
266    char proto_buf[1024];
267    if (getprotobynumber_r(s->proto, &proto, proto_buf, sizeof(proto_buf), &proto_result) == 0) {
268        protocol = proto_result->p_name;
269    } else {
270        snprintf(proto_buf, sizeof(proto_buf), "%d", s->proto);
271        protocol = proto_buf;
272    }
273
274    // listen backlog to string
275
276    char* listening = "";
277    char listening_buf[128];
278    if (s->listen_backlog) {
279        snprintf(listening_buf, sizeof(listening_buf), ", listening (backlog=%d, new_type=%u)", s->listen_backlog, s->new_sock_type);
280        listening = listening_buf;
281    }
282
283    // dump
284
285    _pw_print_indent(fp, indent);
286    fprintf(fp, "socket(%s, %s, %s) fd=%d%s\n", domain, type, protocol, s->sock, listening);
287
288    if (pw_is_sockaddr(&s->local_addr)) {
289        _pw_print_indent(fp, indent);
290        fputs("Local address: ", fp);
291        dump_sockaddr(fp, &s->local_addr);
292    }
293    if (pw_is_sockaddr(&s->remote_addr)) {
294        _pw_print_indent(fp, indent);
295        fputs("Remote address: ", fp);
296        dump_sockaddr(fp, &s->remote_addr);
297    }
298    return true;
299}
300
301static PwInterface_Basic socket_basic_interface = {
302    .create  = { .func = socket_create },
303    .destroy = { .func = socket_destroy },
304    .dump    = { .func = socket_dump }
305};
306
307/****************************************************************
308 * Socket interface
309 */
310
// Interface id of the Socket interface; 0 until _pw_init_socket() registers it
// (_pw_init_socket also uses a non-zero value as its "already initialized" flag).
uint16_t PwInterfaceId_Socket = 0;
312
313[[nodiscard]] static bool make_address(int domain, PwValuePtr addr, socklen_t* ss_addr_size, PwValuePtr result)
314/*
315 * Helper function for bind and connect.
316 *
317 * If `addr` is SockAddr, clone it if address family matches domain.
318 * If `addr` is String, try to convert it to SockAddr.
319 *
320 * Write size of sockaddr atructure to `ss_addr_size`.
321 */
322{
323    *ss_addr_size = sizeof(struct sockaddr_storage);
324
325    if (pw_is_sockaddr(addr)) {
326        _PwSockAddrData* sa = _pw_get_struct_ptr(addr, PwTypeId_SockAddr);
327        if (sa->netmask != 0) {
328            pw_set_status(PwStatus(PweHostAddressExpected));
329            return false;
330        }
331        if (sa->addr.ss_family != domain) {
332            pw_set_status(PwStatus(PweAddressFamilyMismatch));
333            return false;
334        }
335        pw_clone2(result, addr);
336        return true;
337    }
338
339    pw_assert(pw_is_string(addr));
340
341    switch (domain) {
342        case AF_LOCAL: {
343            if (!pw_create(PwTypeId_SockAddr, result)) {
344                return false;
345            }
346
347            _PwSockAddrData* sa = _pw_get_struct_ptr(result, PwTypeId_SockAddr);
348            sa->addr.ss_family = AF_LOCAL;
349
350            unsigned len = pw_strlen_in_utf8(addr);
351            if (len >= sizeof(((struct sockaddr_un*)0)->sun_path)) {
352                pw_set_status(PwStatus(PweSocketNameTooLong));
353                return false;
354            }
355            pw_string_to_utf8(addr, ((struct sockaddr_un*) &sa->addr)->sun_path);
356            *ss_addr_size = offsetof(struct sockaddr_un, sun_path) + len + 1;
357            return true;
358        }
359        case AF_INET:
360        case AF_INET6:
361            return pw_parse_inet_address(addr, result);
362
363        default:
364            pw_panic("Address family %d is not supported yet\n", domain);
365    }
366}
367
368static bool socket_bind(PwMethod_Socket_bind* mthis, PwValuePtr self, PwValuePtr local_addr)
369{
370    _PwSocket* s = pw_this_data(self);
371
372    pw_destroy(&s->local_addr);
373
374    socklen_t addr_size;
375    PwValue addr = PW_NULL;
376    if (!make_address(s->domain, local_addr, &addr_size, &addr)) {
377        return false;
378    }
379
380    pw_move(&s->local_addr, &addr);
381    _PwSockAddrData* sa = _pw_get_struct_ptr(&s->local_addr, PwTypeId_SockAddr);
382
383    if (bind(s->sock, (struct sockaddr*) &sa->addr, addr_size) == 0) {
384        return true;
385    } else {
386        pw_set_status(PwErrno(errno));
387        return false;
388    }
389}
390
391static bool socket_reuse_addr(PwMethod_Socket_reuse_addr* mthis, PwValuePtr self, bool reuse)
392{
393    _PwSocket* s = pw_this_data(self);
394
395    int i = reuse;
396    if (setsockopt(s->sock, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i)) < 0) {
397        pw_set_status(PwErrno(errno));
398        return false;
399    } else {
400        return true;
401    }
402}
403
404static bool socket_listen(PwMethod_Socket_listen* mthis, PwValuePtr self, int backlog, uint16_t new_sock_type)
405{
406    _PwSocket* s = pw_this_data(self);
407
408    if (backlog == 0) {
409        backlog = 5;
410    }
411    if (new_sock_type == PwTypeId_Null) {
412        new_sock_type = self->type_id;
413    }
414    if (listen(s->sock, backlog) == -1) {
415        pw_set_status(PwErrno(errno));
416        return false;
417    }
418    s->listen_backlog = backlog;
419    s->new_sock_type = new_sock_type;
420    return true;
421}
422
423static bool socket_is_listening(PwMethod_Socket_is_listening* mthis, PwValuePtr self, bool* result)
424{
425    _PwSocket* s = pw_this_data(self);
426    *result = s->listen_backlog != 0;
427    return true;
428}
429
430static bool socket_accept(PwMethod_Socket_accept* mthis, PwValuePtr self, PwValuePtr result)
431{
432    _PwSocket* s_lsnr = pw_this_data(self);
433
434    // create uninitialized socket
435    if (!pw_create(s_lsnr->new_sock_type, result)) {
436        return false;
437    }
438
439    _PwSocket* s_new  = pw_this_data(result);
440
441    // initialize addresses for new socket
442    pw_clone2(&s_new->local_addr, &s_lsnr->local_addr);
443    if (!pw_create(PwTypeId_SockAddr, &s_new->remote_addr)) {
444        return false;
445    }
446
447    // get pointer to remote address structure
448    _PwSockAddrData* sa_remote = _pw_get_struct_ptr(&s_new->remote_addr, PwTypeId_SockAddr);
449
450    // call accept and initialize new socket and remote address
451    socklen_t addrlen = sizeof(sa_remote->addr);
452    s_new->sock = accept(s_lsnr->sock, (struct sockaddr*) &sa_remote->addr,  &addrlen);
453    if (s_new->sock < 0) {
454        pw_set_status(PwErrno(errno));
455        return false;
456    }
457    // initialize other fields (they are mainly for informational purposes)
458    s_new->domain = s_lsnr->domain;
459    s_new->type   = s_lsnr->type;
460    s_new->proto  = s_lsnr->proto;
461
462    return true;
463}
464
465static bool socket_connect(PwMethod_Socket_connect* mthis, PwValuePtr self, PwValuePtr remote_addr)
466{
467    _PwSocket* s = pw_this_data(self);
468
469    pw_destroy(&s->local_addr);
470
471    socklen_t addr_size;
472    PwValue addr = PW_NULL;
473    if (!make_address(s->domain, remote_addr, &addr_size, &addr)) {
474        return false;
475    }
476
477    pw_move(&s->remote_addr, &addr);
478    _PwSockAddrData* sa = _pw_get_struct_ptr(&s->remote_addr, PwTypeId_SockAddr);
479
480    int rc;
481    do {
482        rc = connect(s->sock, (struct sockaddr*) &sa->addr, sizeof(sa->addr));
483    } while (rc != 0 && errno == EINTR);
484
485    if (rc == 0) {
486        return true;
487    }
488    if (errno == EAGAIN && sa->addr.ss_family == AF_LOCAL) {
489        // make it consistent
490        errno = EINPROGRESS;
491    }
492    pw_set_status(PwErrno(errno));
493    return false;
494}
495
496static bool socket_shutdown(PwMethod_Socket_shutdown* mthis, PwValuePtr self, int how)
497{
498    _PwSocket* s = pw_this_data(self);
499
500    if (shutdown(s->sock, how) == 0) {
501        return true;
502    } else {
503        pw_set_status(PwErrno(errno));
504        return false;
505    }
506}
507
508static bool socket_get_socket_error(PwMethod_Socket_get_socket_error* mthis, PwValuePtr self, PwValuePtr result)
509{
510    _PwSocket* s = pw_this_data(self);
511    int err;
512    socklen_t len = sizeof(int);
513    if (getsockopt(s->sock,  SOL_SOCKET, SO_ERROR, (void*) &err, &len) == -1) {
514        fprintf(stderr, "%s:%d getsockopt: %s\n", __FILE__, __LINE__, strerror(errno));
515        return false;
516    }
517    pw_destroy(result);
518    if (err) {
519        *result = PwErrno(err);
520    } else {
521        *result = PwSuccess();
522    }
523    return true;
524}
525
526static PwInterface_Socket socket_interface = {
527#define X(name, ...) .name = { .func = socket_##name } __VA_OPT__(,)
528    PW_SOCKET_INTERFACE_METHODS
529#undef X
530};
531
532/****************************************************************
533 * Fd interface
534 */
535
536static bool socket_get_fd(PwMethod_Fd_get_fd* mthis, PwValuePtr self, int* result)
537{
538    _PwSocket* s = pw_this_data(self);
539    *result = s->sock;
540    return true;
541}
542
543static bool socket_set_fd(PwMethod_Fd_set_fd* mthis, PwValuePtr self, int fd, bool move)
544{
545    _PwSocket* s = pw_this_data(self);
546
547    if (s->sock != -1) {
548        // fd already set
549        pw_set_status(PwStatus(PweFdAlreadySet));
550        return false;
551    }
552    s->sock = fd;
553    s->own_fd = move;
554    return true;
555}
556
557static bool socket_close(PwMethod_Fd_close* mthis, PwValuePtr self)
558{
559    close_socket(pw_this_data(self));
560    return true;
561}
562
563static bool socket_set_nonblocking(PwMethod_Fd_set_nonblocking* mthis, PwValuePtr self, bool mode)
564{
565    _PwSocket* s = pw_this_data(self);
566
567    int flags = fcntl(s->sock, F_GETFL, 0);
568    if (mode) {
569        flags |= O_NONBLOCK;
570    } else {
571        flags &= ~O_NONBLOCK;
572    }
573    if (fcntl(s->sock, F_SETFL, flags) == -1) {
574        pw_set_status(PwErrno(errno));
575        return false;
576    } else {
577        return true;
578    }
579}
580
581static PwInterface_Fd socket_fd_interface = {
582#define X(name, ...) .name = { .func = socket_##name } __VA_OPT__(,)
583    PW_FD_INTERFACE_METHODS
584#undef X
585};
586
587/****************************************************************
588 * Reader and writer interfaces
589 */
590
591static bool socket_read(PwMethod_Reader_read* mthis, PwValuePtr self, void* buffer, unsigned buffer_size, unsigned* bytes_read)
592{
593    _PwSocket* s = pw_this_data(self);
594
595    ssize_t result;
596    do {
597        result = read(s->sock, buffer, buffer_size);
598    } while (result < 0 && errno == EINTR);
599
600    if (result < 0) {
601        pw_set_status(PwErrno(errno));
602        return false;
603    } else {
604        *bytes_read = (unsigned) result;
605        return true;
606    }
607}
608
609static bool socket_write(PwMethod_Writer_write* mthis, PwValuePtr self, void* data, unsigned size, unsigned* bytes_written)
610{
611    _PwSocket* s = pw_this_data(self);
612
613    ssize_t result;
614    do {
615        result = write(s->sock, data, size);
616    } while (result < 0 && errno == EINTR);
617
618    if (result < 0) {
619        pw_set_status(PwErrno(errno));
620        return false;
621    } else {
622        *bytes_written = (unsigned) result;
623        return true;
624    }
625}
626
627static PwInterface_Reader socket_reader_interface = {
628#define X(name, ...) .name = { .func = socket_##name } __VA_OPT__(,)
629    PW_READER_INTERFACE_METHODS
630#undef X
631};
632
633static PwInterface_Writer socket_writer_interface = {
634#define X(name, ...) .name = { .func = socket_##name } __VA_OPT__(,)
635    PW_WRITER_INTERFACE_METHODS
636#undef X
637};
638
639/****************************************************************
640 * Initialization
641 */
642
// Module initializer. It registers the Socket interface, this module's status
// codes and the SockAddr and Socket types.
// [[gnu::constructor]] makes it run automatically at load time. The early
// return on a non-zero PwInterfaceId_Socket turns any later call into a no-op.
[[gnu::constructor]]
void _pw_init_socket()
{
    if (PwInterfaceId_Socket) {
        return;
    }

    // presumably initializes the base types/interfaces used below
    // (Struct, Basic, Fd, Reader, Writer) -- confirm
    _pw_init_types();

    // Register the Socket interface: X turns each method in
    // PW_SOCKET_INTERFACE_METHODS into its name as a string literal
    // (the trailing nullptr presumably terminates the list).
#   define X(name, ...) #name __VA_OPT__(,)
    PwInterfaceId_Socket = pw_register_interface("Socket", PW_SOCKET_INTERFACE_METHODS, nullptr);
#   undef X

    // init status codes
    // NOTE(review): the assigned ids presumably follow registration order; keep it stable.
    PweBadAddressFamily      = pw_define_status("Bad address family");
    PweBadIpAddress          = pw_define_status("Bad IP address");
    PweBadPort               = pw_define_status("Bad port");
    PweHostAddressExpected   = pw_define_status("Host address expected");
    PweAddressFamilyMismatch = pw_define_status("Address family mismatch");
    PweSocketNameTooLong     = pw_define_status("Socket name too long");
    PweMissingNetmask        = pw_define_status("Missing netmask");
    PweBadNetmask            = pw_define_status("Bad netmask");
    PwePortUnspecified       = pw_define_status("Port unspecified");

    // init types

    // SockAddr: Struct subtype with Basic overrides (hash, dump)
    PwTypeId_SockAddr = pw_add_type2(
        "SockAddr", _PwSockAddrData,
        PW_PARENTS,
            PwTypeId_Struct,
        PW_INTERFACES,
            PwInterfaceId_Basic, &sockaddr_basic_interface
    );

    // Socket: Struct subtype implementing Basic, Fd, Socket, Reader and Writer
    PwTypeId_Socket = pw_add_type2(
        "Socket", _PwSocket,
        PW_PARENTS,
            PwTypeId_Struct,
        PW_INTERFACES,
            PwInterfaceId_Basic,  &socket_basic_interface,
            PwInterfaceId_Fd,     &socket_fd_interface,
            PwInterfaceId_Socket, &socket_interface,
            PwInterfaceId_Reader, &socket_reader_interface,
            PwInterfaceId_Writer, &socket_writer_interface
    );
}