Implemented swift_connect function
[swifty.git] / src / kernel / swift.c
1 #include <linux/module.h>
2 #include <net/sock.h>
3 #include <net/protocol.h>
4 #include <net/ip.h>
5 #include <net/route.h>
6
7 #include "swift.h"
8
9 MODULE_DESCRIPTION("Swift Transport Protocol");
10 MODULE_AUTHOR("Adrian Bondrescu/Cornel Mercan");
11 MODULE_LICENSE("GPL");
12
13 struct swift_sock {
14         struct inet_sock sock;
15         /* swift socket speciffic data */
16         __be16 src;
17         __be16 dst;
18 };
19
20 static struct swift_sock * sock_port_map[MAX_SWIFT_PORT];
21
22 static inline struct swift_sock * swift_sk(struct sock * sock)
23 {
24         return (struct swift_sock *)(sock);
25 }
26
27 static inline struct swifthdr * swift_hdr(const struct sk_buff * skb)
28 {
29         return (struct swifthdr *) skb_transport_header(skb);
30 }
31
32 static inline __be16 get_next_free_port(void)
33 {
34         int i;
35         for (i = MIN_SWIFT_PORT; i < MAX_SWIFT_PORT; i ++)
36                 if (sock_port_map[i] == NULL)
37                         return i;
38         return 0;
39 }
40
41 static inline void swift_unhash(__be16 port)
42 {
43         sock_port_map[port] = NULL;
44 }
45
46 static inline void swift_hash(__be16 port, struct swift_sock *ssh)
47 {
48         sock_port_map[port] = ssh;
49 }
50
51 static inline struct swift_sock * swift_lookup(__be16 port)
52 {
53         return sock_port_map[port];
54 }
55
56 static int swift_release(struct socket *sock)
57 {
58         struct sock *sk = sock->sk;
59         struct swift_sock * ssk = swift_sk(sk);
60
61         if (!sk)
62                 return 0;
63
64         swift_unhash(ssk->src);
65         
66         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
67
68         synchronize_net();
69
70         sock_orphan(sk);
71         sock->sk = NULL;
72
73         skb_queue_purge(&sk->sk_receive_queue);
74
75         printk(KERN_DEBUG "swift_release sock=%p\n", sk);
76         sock_put(sk);
77
78         return 0;
79 }
80
81 static int swift_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
82 {
83         struct sockaddr_swift *swift_addr;
84         struct swift_sock *ssk;
85         int err;
86         __be16 port;
87
88         err = -EINVAL;
89         if (addr_len < sizeof(struct sockaddr_swift)) {
90                 printk(KERN_ERR "Invalid size for sockaddr\n");
91                 goto out;
92         }
93
94         swift_addr = (struct sockaddr_swift *) addr;
95
96         err = -EINVAL;
97         if (swift_addr->sin_family != AF_INET) {
98                 printk(KERN_ERR "Invalid family for sockaddr\n");
99                 goto out;
100         }
101
102         port = ntohs(swift_addr->sin_port);
103
104         err = -EINVAL;
105         if (port == 0 || port >= MAX_SWIFT_PORT) {
106                 printk(KERN_ERR "Invalid value for sockaddr port (%u)\n", port);
107                 goto out;
108         }
109         
110         err = -EADDRINUSE;
111         if (swift_lookup(port) != NULL) {
112                 printk(KERN_ERR "Port %u already in use\n", port);
113                 goto out;
114         }
115
116         ssk = swift_sk(sock->sk);
117         ssk->src = port;
118
119         swift_hash(port, ssk);
120
121         printk(KERN_DEBUG "Socket %p bound to port %u\n", ssk, port);
122         
123         return 0;
124
125 out:
126         return -EINVAL;
127 }
128
129 static int swift_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags)
130 {
131         int err;
132         struct sock * sk; 
133         struct inet_sock * isk;
134         struct swift_sock * ssk;
135
136         printk(KERN_DEBUG "swift_connect\n");
137
138         err = -EINVAL;
139         if (sock == NULL) {
140                 printk(KERN_ERR "Sock is NULL\n");
141                 goto out;
142         }
143         sk = sock->sk;
144
145         err = -EINVAL;
146         if (sk == NULL) {
147                 printk(KERN_ERR "Sock->sk is NULL\n");
148                 goto out;
149         }
150         
151         isk = inet_sk(sk);
152         ssk = swift_sk(sk);
153
154         if (ssk->src != 0) {
155                 printk(KERN_ERR "ssk->src is not NULL\n");
156                 goto out;
157         }
158         
159         err = -EINVAL;
160         if (addr) {
161                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) addr;
162                 
163                 err = -EINVAL;
164                 if (addr_len < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET) {
165                         printk(KERN_ERR "Invalid size or address family\n");
166                         goto out;
167                 }
168                 ssk->dst = ntohs(swift_addr->sin_port);
169                 if (ssk->dst == 0 || ssk->dst >= MAX_SWIFT_PORT) {
170                         printk(KERN_ERR "Invalid value for destination port(%u)\n", ssk->dst);
171                         goto out;
172                 }       
173         
174                 isk->inet_daddr = swift_addr->sin_addr.s_addr;
175                 printk(KERN_DEBUG "Received from user space destination port=%u and address=%u\n", ssk->dst, isk->inet_daddr);
176         } else {
177                 printk(KERN_ERR "Invalid swift_addr (NULL)\n");
178                 goto out;
179         }
180         
181         err = -ENOMEM;
182         ssk->src = get_next_free_port();
183         if (ssk->src == 0) {
184                 printk(KERN_ERR "No free ports\n");
185                 goto out;
186         }
187         
188         swift_hash(ssk->src, ssk);
189
190         return 0;
191
192 out:
193         return err;
194
195 }
196
197 static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len)
198 {
199         int err;
200         __be16 dport;
201         __be32 daddr;
202         __be16 sport;
203         struct sk_buff * skb;
204         struct sock * sk; 
205         struct inet_sock * isk;
206         struct swift_sock * ssk;
207         struct swifthdr * shdr;
208         int connected = 0;
209         int totlen;
210         struct rtable * rt = NULL;
211         
212         err = -EINVAL;
213         if (sock == NULL) {
214                 printk(KERN_ERR "Sock is NULL\n");
215                 goto out;
216         }
217         sk = sock->sk;
218
219         err = -EINVAL;
220         if (sk == NULL) {
221                 printk(KERN_ERR "Sock->sk is NULL\n");
222                 goto out;
223         }
224
225         isk = inet_sk(sk);
226         ssk = swift_sk(sk);
227
228         sport = ssk->src;
229         if (sport == 0) {
230                 err = -ENOMEM;
231                 sport = get_next_free_port();
232                 if (sport == 0) {
233                         printk(KERN_ERR "No free ports\n");
234                         goto out;
235                 }
236         }
237
238         if (msg->msg_name) {
239                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) msg->msg_name;
240                 
241                 err = -EINVAL;
242                 if (msg->msg_namelen < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET) {
243                         printk(KERN_ERR "Invalid size or address family\n");
244                         goto out;
245                 }
246                 
247                 dport = ntohs(swift_addr->sin_port);
248                 if (dport == 0 || dport >= MAX_SWIFT_PORT) {
249                         printk(KERN_ERR "Invalid value for destination port(%u)\n", dport);
250                         goto out;
251                 }       
252
253                 daddr = swift_addr->sin_addr.s_addr;
254                 printk(KERN_DEBUG "Received from user space destination port=%u and address=%u\n", dport, daddr);
255         } else {
256                 err = -EDESTADDRREQ;
257                 if (!ssk->dst || !isk->inet_daddr) {
258                         printk(KERN_ERR "No destination port/address\n");
259                         goto out;
260                 }
261                 dport = ssk->dst;
262                 daddr = isk->inet_daddr;
263
264                 printk(KERN_DEBUG "Got from socket destination port=%u and address=%u\n", dport, daddr);
265                 connected = 1;
266         }
267
268         totlen = len + sizeof(struct swifthdr) + sizeof(struct iphdr);
269         skb = sock_alloc_send_skb(sk, totlen, msg->msg_flags & MSG_DONTWAIT, &err);
270         if (!skb) {
271                 printk(KERN_ERR "sock_alloc_send_skb failed\n");
272                 goto out;
273         }
274         printk(KERN_DEBUG "Allocated %u bytes for skb (payload size=%u)\n", totlen, len);
275
276         skb_reset_network_header(skb);
277         skb_reserve(skb, sizeof(struct iphdr));
278         printk(KERN_DEBUG "Reseted network header\n");
279         skb_reset_transport_header(skb);
280         skb_put(skb, sizeof(struct swifthdr));
281         printk(KERN_DEBUG "Reseted transport header\n");
282
283         shdr = (struct swifthdr *) skb_transport_header(skb);
284         shdr->dst = ntohs(dport);
285         shdr->src = ntohs(sport);
286         shdr->len = ntohs(len + sizeof(struct swifthdr));
287
288         printk(KERN_DEBUG "payload=%p\n", skb_put(skb, len));
289
290         err = skb_copy_datagram_from_iovec(skb, sizeof(struct swifthdr), msg->msg_iov, 0, len);
291         if (err) {
292                 printk(KERN_ERR "skb_copy_datagram_from_iovec failed\n");
293                 goto out_free;
294         }
295         printk(KERN_DEBUG "Copied %u bytes into the skb\n", len);
296
297         if (connected)
298                 rt = (struct rtable *) __sk_dst_check(sk, 0);
299
300         if (rt == NULL) {
301                 struct flowi fl = { .fl4_dst = daddr,
302                                     .proto = sk->sk_protocol,
303                                     .flags = inet_sk_flowi_flags(sk),
304                                   };
305                 err = ip_route_output_flow(sock_net(sk), &rt, &fl, sk, 0);
306                 if (err) {
307                         printk(KERN_ERR "Route lookup failed\n");
308                         goto out_free;
309                 }
310                 sk_dst_set(sk, dst_clone(&rt->dst));
311         }
312         
313         err = ip_queue_xmit(skb);
314         if (!err)
315                 printk(KERN_DEBUG "Sent %u bytes on wire\n", len);
316         else
317                 printk(KERN_ERR "ip_queue_xmit failed\n");
318
319         return err;
320
321 out_free:
322         kfree(skb);
323
324 out:
325         return err;
326 }
327
328 static int swift_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len, int flags)
329 {
330         struct sk_buff *skb;
331         struct sockaddr_swift *swift_addr;
332         struct sock * sk = sock->sk;
333         int err, copied;
334
335         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &err);
336         if (!skb) {
337                 printk(KERN_ERR "skb_recv_datagram\n");
338                 goto out;
339         }
340
341         printk(KERN_DEBUG "Received skb %p\n", skb);
342
343         swift_addr = (struct sockaddr_swift *) skb->cb;
344         msg->msg_namelen = sizeof(struct sockaddr_swift);
345
346         copied = skb->len;
347         if (copied > len) {
348                 copied = len;
349                 msg->msg_flags |= MSG_TRUNC;
350         }
351
352         err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, copied);
353         if (err) {
354                 printk(KERN_ERR "skb_copy_datagram_iovec\n");
355                 goto out_free;
356         }
357
358         sock_recv_ts_and_drops(msg, sk, skb);
359
360         if (msg->msg_name)
361                 memcpy(msg->msg_name, swift_addr, msg->msg_namelen);
362         
363         err = copied;
364
365 out_free:
366         skb_free_datagram(sk, skb);
367
368 out:
369         return err;
370 }
371
372 static int swift_rcv(struct sk_buff *skb)
373 {
374         struct swifthdr *shdr;
375         struct swift_sock *ssk;
376         __be16 len;
377         __be16 src, dst;
378         struct sockaddr_swift * swift_addr;
379         int err;
380
381         if (!pskb_may_pull(skb, sizeof(struct swifthdr))) {
382                 printk(KERN_ERR "Insufficient space for header\n");
383                 goto drop;
384         }
385         
386         shdr = (struct swifthdr *) skb->data;
387         len = ntohs(shdr->len);
388
389         if (skb->len < len) {
390                 printk(KERN_ERR "Malformed packet (packet_len=%u, skb_len=%u)\n", len, skb->len);
391                 goto drop;
392         }
393
394         if (len < sizeof(struct swifthdr)) {
395                 printk(KERN_ERR "Malformed packet (packet_len=%u sizeof(swifthdr)=%u\n", len, sizeof(struct swifthdr));
396                 goto drop;
397         }
398         
399         src = ntohs(shdr->src);
400         dst = ntohs(shdr->dst);
401         if (src == 0 || dst == 0 || src >= MAX_SWIFT_PORT || dst >= MAX_SWIFT_PORT) {
402                 printk(KERN_ERR "Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst);
403                 goto drop;
404         }
405
406         skb_pull(skb, sizeof(struct swifthdr));
407         len -= sizeof(struct swifthdr);
408
409         pskb_trim(skb, len);
410
411         printk(KERN_DEBUG "Received %u bytes from from port=%u to port=%u\n", len - sizeof(struct swifthdr), src, dst);
412
413         ssk = swift_lookup(dst); 
414         if (ssk == NULL) {
415                 printk(KERN_ERR "Swift lookup failed for port %u\n", dst);
416                 goto drop;
417         }
418
419         BUILD_BUG_ON(sizeof(struct sockaddr_swift) > sizeof(skb->cb));
420         
421         swift_addr = (struct sockaddr_swift *) skb->cb;
422         swift_addr->sin_family = AF_INET;
423         swift_addr->sin_port = shdr->src;
424         swift_addr->sin_addr.s_addr = ip_hdr(skb)->saddr;
425
426         printk(KERN_DEBUG "Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src), swift_addr->sin_addr.s_addr);
427
428         err = ip_queue_rcv_skb((struct sock *) &ssk->sock, skb);
429         if (err) {
430                 printk(KERN_ERR "ip_queu_rcv_skb\n");
431                 consume_skb(skb);
432         }
433         return NET_RX_SUCCESS;
434
435 drop:
436         kfree(skb);
437         return NET_RX_DROP;
438 }
439
440 static struct proto swift_prot = {
441         .obj_size = sizeof(struct swift_sock),
442         .owner    = THIS_MODULE,
443         .name     = "SWIFT",
444 };
445
446 static const struct proto_ops swift_ops = {
447         .family     = PF_INET,
448         .owner      = THIS_MODULE,
449         .release    = swift_release,
450         .bind       = swift_bind,
451         .connect    = swift_connect,
452         .socketpair = sock_no_socketpair,
453         .accept     = sock_no_accept,
454         .getname    = sock_no_getname,
455         .poll       = datagram_poll,
456         .ioctl      = sock_no_ioctl,
457         .listen     = sock_no_listen,
458         .shutdown   = sock_no_shutdown,
459         .setsockopt = sock_no_setsockopt,
460         .getsockopt = sock_no_getsockopt,
461         .sendmsg    = swift_sendmsg,
462         .recvmsg    = swift_recvmsg,
463         .mmap       = sock_no_mmap,
464         .sendpage   = sock_no_sendpage,
465 };
466
467 static const struct net_protocol swift_protocol = {
468         .handler   = swift_rcv,
469         .no_policy = 1,
470         .netns_ok  = 1,
471 };
472
473 static struct inet_protosw swift_protosw = {
474         .type     = SOCK_DGRAM,
475         .protocol = IPPROTO_SWIFT,
476         .prot     = &swift_prot,
477         .ops      = &swift_ops,
478         .no_check = 0,
479 };
480
481 static int __init swift_init(void)
482 {
483         int rc;
484
485         rc = proto_register(&swift_prot, 1);
486         if (rc) {
487                 printk(KERN_ERR "Error registering swift protocol\n");
488                 goto out;
489         }
490
491         rc = inet_add_protocol(&swift_protocol, IPPROTO_SWIFT);
492         if (rc) {
493                 printk(KERN_ERR "Error adding swift protocol\n");
494                 goto out_unregister;
495         }
496
497         inet_register_protosw(&swift_protosw);
498         printk(KERN_DEBUG "Swift entered\n");
499
500         return 0;
501
502 out_unregister:
503         proto_unregister(&swift_prot);
504
505 out:
506         return rc;
507 }
508
509 static void __exit swift_exit(void)
510 {
511         inet_unregister_protosw(&swift_protosw);
512
513         inet_del_protocol(&swift_protocol, IPPROTO_SWIFT);
514
515         proto_unregister(&swift_prot);
516
517         printk(KERN_DEBUG "Swift exited\n");
518 }
519
520 module_init(swift_init);
521 module_exit(swift_exit);