Initial revision
[tedtools.git] / tcp.c
1 /*
2  * Copyright (c) 2004 Teodor Sigaev <teodor@sigaev.ru>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *        notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *        notice, this list of conditions and the following disclaimer in the
12  *        documentation and/or other materials provided with the distribution.
13  * 3. Neither the name of the author nor the names of any co-contributors
14  *        may be used to endorse or promote products derived from this software
15  *        without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY CONTRIBUTORS ``AS IS'' AND ANY EXPRESS
18  * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED. IN NO EVENT SHALL CONTRIBUTORS BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
23  * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
25  * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
26  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
27  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28  */
#include <stdio.h>
#include <errno.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
35
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #else /* HAVE_POLL */
39 #ifdef HAVE_SYS_POLL_H
40 #include <sys/poll.h>
41 #else
42 #error Not defined HAVE_POLL_H or HAVE_SYS_POLL_H
43 #endif /* HAVE_SYS_POLL_H */
44 #endif /* HAVE_POLL */
45
46 #ifdef HAVE_HSTRERROR 
47 #include <netdb.h>
48 #endif
49
50
51 #include "connection.h"
52 #include "tlog.h"
53 #include "tmalloc.h"
54
55 u_int32_t
56 TC_ClientInitConnection(TC_Connection *cs, char *name, u_int32_t port) {
57         int flags;
58
59         cs = TC_fillConnection(cs, name, port);
60
61         cs->state = CS_OK;
62         if ((cs->fd= socket(AF_INET, SOCK_STREAM, 0)) < 0)
63                 tlog(TL_CRIT|TL_EXIT,"socket4: %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
64                         ntohs(cs->serv_addr.sin_port),strerror(errno));
65
66         if ((flags=fcntl(cs->fd,F_GETFL,0)) == -1)
67                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
68         if (fcntl(cs->fd,F_SETFL,flags|O_NDELAY) < 0 )
69                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
70
71         if (bind(cs->fd, (struct sockaddr *) &(cs->serv_addr), sizeof(cs->serv_addr)) < 0)
72                 tlog(TL_CRIT|TL_EXIT, "cannot bind to %s address: %s",
73                         inet_ntoa(cs->serv_addr.sin_addr), strerror(errno));
74         
75         if (listen(cs->fd, 0) < 0)
76                 tlog(TL_CRIT|TL_EXIT, "cannot listen to %s address: %s",
77                         inet_ntoa(cs->serv_addr.sin_addr), strerror(errno));
78         
79         return CS_OK;
80 }
81
82 TC_Connection*
83 TC_AcceptTcp(TC_Connection *cs) {
84         TC_Connection *nc;
85         struct sockaddr_in cli_addr;
86         int ret, flags;
87         socklen_t clilen = sizeof(cli_addr);
88
89         cs->state = CS_READ;
90         if ( (ret = accept(cs->fd,(struct sockaddr *)&cli_addr, &clilen)) < 0 ) {
91                 if ( errno == EAGAIN || errno == EWOULDBLOCK )
92                         return NULL;
93                 tlog(TL_ALARM,"TC_AcceptTcp: accept: %s", strerror(errno));
94                 return NULL;
95         }
96         nc = (TC_Connection*)t0malloc(sizeof(TC_Connection));
97
98         nc->fd = ret;
99         if ((flags=fcntl(nc->fd,F_GETFL,0)) == -1)
100                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
101         if (fcntl(nc->fd,F_SETFL,flags|O_NDELAY) < 0 )
102                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
103         memcpy( &(nc->serv_addr), &cli_addr, clilen );
104         nc->state = CS_CONNECTED;
105
106         return nc;
107 }
108
109 TC_Connection *
110 TC_fillConnection(TC_Connection *sc, char *name, u_int32_t port) {
111         if ( !sc ) 
112                 sc = (TC_Connection *)t0malloc(sizeof(TC_Connection));
113         sc->serv_addr.sin_family = AF_INET;
114         sc->serv_addr.sin_addr.s_addr = (name) ? inet_addr(name) : htonl(INADDR_ANY);
115         sc->serv_addr.sin_port = htons(port);
116         sc->state = CS_NOTINITED;
117         return sc; 
118 }
119
120 static u_int32_t
121 setlinger( TC_Connection *cs ) {
122         struct linger ling;
123         int     val = 0;
124         socklen_t size = sizeof(val); 
125
126         if (getsockopt(cs->fd, SOL_SOCKET,SO_ERROR,&val,&size) == -1) {
127                 tlog(TL_ALARM,"getsockopt: %s:%d - %s(%d)",inet_ntoa(cs->serv_addr.sin_addr),
128                         ntohs(cs->serv_addr.sin_port), strerror(errno), errno);
129                 shutdown(cs->fd,SHUT_RDWR);
130                 close(cs->fd);
131                 cs->fd = 0;
132                 cs->state = CS_ERROR;
133                 return CS_ERROR;
134         }
135
136         if ( val ) {
137                 tlog(TL_ALARM,"getsockopt return: %s:%d - %s(%d)",inet_ntoa(cs->serv_addr.sin_addr),
138                         ntohs(cs->serv_addr.sin_port), strerror(val), val);
139                 shutdown(cs->fd,SHUT_RDWR);
140                 close(cs->fd);
141                 cs->fd = 0;
142                 cs->state = CS_ERROR;
143                 return CS_ERROR;
144         }
145
146
147         ling.l_onoff = ling.l_linger = 0;
148         if (setsockopt(cs->fd, SOL_SOCKET,SO_LINGER,(char *)&ling,sizeof(ling))==-1) {
149                 tlog(TL_ALARM,"setsockopt: LINGER %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
150                         strerror(errno));
151                 shutdown(cs->fd,SHUT_RDWR);
152                 close(cs->fd);
153                 cs->fd = 0;
154                 cs->state = CS_ERROR;
155                 return CS_ERROR;
156         }
157         cs->state = CS_CONNECTED;
158         return CS_CONNECTED;
159 }
160
161 u_int32_t
162 TC_ServerInitConnect( TC_Connection     *cs ) {
163         int flags;
164
165         if ( cs->state == CS_ERROR )
166                 return CS_ERROR;
167
168         if ((cs->fd= socket(AF_INET, SOCK_STREAM, 0)) < 0) {
169                 tlog(TL_CRIT,"socket4: %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
170                         ntohs(cs->serv_addr.sin_port),strerror(errno));
171                 cs->state  = CS_ERROR;
172                 return  CS_ERROR;
173         }
174
175         if ((flags=fcntl(cs->fd,F_GETFL,0)) == -1)
176                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
177         if (fcntl(cs->fd,F_SETFL,flags|O_NDELAY) < 0 )
178                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
179
180         if ( connect(cs->fd, (struct sockaddr *) &(cs->serv_addr),
181                 sizeof(struct sockaddr_in)) < 0 ) {
182                 if ( errno == EINPROGRESS || errno == EALREADY ) {
183                         cs->state = CS_INPROCESS;
184                         return CS_INPROCESS; 
185                 } else if (errno != EISCONN && errno != EALREADY &&
186                         errno != EWOULDBLOCK && errno != EAGAIN) {
187                         tlog(TL_DEBUG,"open4: %s:%d - %s",
188                                 inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port),
189                                 strerror(errno));
190                         shutdown(cs->fd,SHUT_RDWR);
191                         close(cs->fd);
192                         cs->fd = 0;
193                 } else {
194                         tlog(TL_DEBUG,"nonblock connect: %s:%d - %s [%d]",
195                                 inet_ntoa(cs->serv_addr.sin_addr),
196                                 ntohs(cs->serv_addr.sin_port),
197                                 strerror(errno),errno);
198                 }
199                 cs->state = CS_ERROR;
200                 return CS_ERROR;
201         }
202
203         cs->state = CS_INPROCESS;
204         return TC_ServerConnect( cs );
205 }
206         
207
208 u_int32_t
209 TC_ServerConnect( TC_Connection *cs ) {
210         struct pollfd   pfd;
211         int ret;
212
213         if ( cs->state != CS_INPROCESS )
214                 return cs->state;
215
216         pfd.fd = cs->fd;
217         pfd.events = POLLOUT;
218         pfd.revents = 0;
219         ret = poll( &pfd, 1, 0 );
220         if ( ret<0 ) {
221                 tlog( TL_CRIT, "TC_ServerConnect: poll: %s",
222                         strerror(errno));
223                 cs->state = CS_ERROR;
224                 return CS_ERROR;
225         } else if ( ret == 0 ) 
226                 return CS_INPROCESS;
227
228         if ( (pfd.revents & (POLLHUP | POLLNVAL | POLLERR)) ) {
229                 tlog( TL_CRIT, "TC_ServerConnect: poll return connect error for %s:%d",
230                         inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port));
231                 cs->state = CS_ERROR;
232                 return CS_ERROR;
233         }
234
235         if ( ! (pfd.revents & POLLOUT) )
236                 return CS_INPROCESS;
237
238
239         return setlinger( cs );
240 }
241
242 int
243 TC_ReadyIO( TC_Connection **cs, int number, int timeout ) {
244         struct pollfd   *pfd;
245         int ret,i, fdnum=0;
246
247         if ( number==0 || cs ==NULL ) {
248                 usleep( timeout * 1000.0 );
249                 return 0;
250         }
251         pfd = (struct pollfd*) tmalloc( sizeof(struct pollfd) * number );
252
253         for(i=0; i<number;i++) {
254                 if ( cs[i]->fd>0 && (cs[i]->state == CS_READ || cs[i]->state == CS_SEND) ) {
255                         pfd[fdnum].fd = cs[i]->fd;
256                         pfd[fdnum].events = ( cs[i]->state == CS_READ ) ? POLLIN : POLLOUT;
257                         pfd[fdnum].revents = 0;
258                         fdnum++;
259                 }
260                 cs[i]->readyio=0;
261         }
262         ret = poll( pfd, fdnum, timeout );
263         if ( ret<0 ) {
264                 tlog( TL_CRIT, "TC_ReadyIO: poll: %s",
265                         strerror(errno));
266                 tfree(pfd);
267                 return 0;
268         }
269
270         if ( ret == 0 ) {
271                 tfree(pfd);
272                 return 0;
273         }
274
275         fdnum=0; ret=0;
276         for(i=0; i<number;i++) {
277                 if ( cs[i]->fd>0 && (cs[i]->state == CS_READ || cs[i]->state == CS_SEND) ) {
278                         if ( pfd[fdnum].revents & (POLLHUP | POLLNVAL | POLLERR) ) { 
279                                 tlog( TL_ALARM, "TC_ReadyIO: poll return error for %s:%d",
280                                         inet_ntoa(cs[i]->serv_addr.sin_addr), 
281                                         ntohs(cs[i]->serv_addr.sin_port));
282                                 cs[i]->state = CS_ERROR;
283                                 ret = 1;
284                         } else if ( pfd[fdnum].revents & ( ( cs[i]->state == CS_READ ) ? POLLIN : POLLOUT ) ) {
285                                 cs[i]->readyio=1;
286                                 ret = 1;
287                         }
288                         fdnum++;
289                 }
290         }
291
292         tfree(pfd);
293         return ret;
294 }
295
296 u_int32_t
297 TC_Send( TC_Connection *cs ) {
298         int sz;
299         
300         if ( cs->state == CS_ERROR )
301                 return CS_ERROR;
302
303         if ( cs->state != CS_SEND ) {
304                 cs->state = CS_SEND;
305                 cs->ptr = cs->buf;
306         }
307
308         if ( cs->ptr - cs->buf >= cs->len ) {
309                 cs->state = CS_FINISHSEND;
310                 return CS_FINISHSEND;
311         }
312
313         if ((sz=write(cs->fd, cs->ptr, cs->len - (cs->ptr - cs->buf)))==0 ||
314                 (sz < 0 && (errno == EWOULDBLOCK || errno == EAGAIN))) {
315
316                 /* SunOS 4.1.x, are broken and select() says that
317                  * O_NDELAY sockets are always writable even when
318                  * they're actually not.
319                  */
320                 cs->state = CS_SEND;
321                 return CS_SEND;
322         }
323         if ( sz<0 ) {
324                 if (errno != EPIPE && errno != EINVAL)
325                         tlog(TL_ALARM, "write[%s:%d] - %s",
326                                 inet_ntoa(cs->serv_addr.sin_addr),
327                                 ntohs(cs->serv_addr.sin_port), 
328                                 strerror(errno));
329                 cs->state = CS_ERROR;
330                 return CS_ERROR;
331         }
332
333         cs->ptr += sz;
334
335         if ( cs->ptr - cs->buf >= cs->len ) {
336                 cs->state = CS_FINISHSEND;
337                 return CS_FINISHSEND;
338         }
339         
340         return cs->state;
341 }
342
343 static void 
344 resizeCS( TC_Connection *cs, int sz ) {
345         int diff = cs->ptr - cs->buf;
346         if ( cs->len >= sz )
347                 return; 
348         cs->len = sz;
349         cs->buf = (char*)trealloc( (void*)cs->buf, cs->len );
350         cs->ptr = cs->buf + diff;
351 }
352
353 u_int32_t
354 TC_Read( TC_Connection *cs ) {
355         int sz, totalread = -1, toread=0, alreadyread;
356
357         if ( cs->state == CS_ERROR )
358                 return CS_ERROR;
359
360         if (cs->state != CS_READ ) {
361                 cs->state = CS_READ;
362                 cs->ptr = cs->buf;
363         }
364
365         alreadyread = cs->ptr - cs->buf;
366         if ( alreadyread < sizeof(u_int32_t) ) {
367                 toread = sizeof(u_int32_t) - alreadyread;
368                 resizeCS(cs, sizeof(u_int32_t));
369         } else {
370                 totalread = *(u_int32_t*)(cs->buf);
371                 toread = totalread - alreadyread;
372                 if ( toread == 0 ) {
373                         cs->state = CS_FINISHREAD;
374                         return CS_FINISHREAD;
375                 }
376                 resizeCS(cs, totalread);
377         }
378
379         if ((sz=read( cs->fd, cs->ptr, toread))<0) {
380                 if (errno == EAGAIN || errno == EINTR) {
381                         cs->state = CS_READ;
382                         return CS_READ;
383                 }
384                 tlog(TL_ALARM,"read: finish - %s",strerror(errno));
385                 cs->state = CS_ERROR;
386                 return CS_ERROR;
387         }
388         
389
390         cs->ptr += sz;
391         alreadyread += sz;
392         if ( sz == 0 && alreadyread != totalread ) {
393                 tlog(TL_ALARM,"read: disconnecting");
394                 cs->state = CS_ERROR;
395                 return CS_ERROR;
396         }
397         cs->state = ( alreadyread == totalread ) ? CS_FINISHREAD : CS_READ;
398         return cs->state;
399 }
400
401 void
402 TC_FreeConnection( TC_Connection *cs ) {
403         if ( cs->state == CS_CLOSED )
404                 return;
405         if ( cs->buf ) {
406                 tfree(cs->buf);
407                 cs->buf = NULL;
408         }
409         if ( cs->fd && cs->state != CS_NOTINITED ) {
410                 shutdown(cs->fd,SHUT_RDWR);
411                 close(cs->fd);
412         }
413         cs->fd = 0;
414         cs->state = CS_CLOSED;
415 }
416
417 u_int32_t 
418 TC_Talk( TC_Connection *cs ) {
419         u_int32_t ret = TC_ServerInitConnect( cs );
420
421         while( ret == CS_INPROCESS ) {
422                 ret =  TC_ServerConnect(cs);
423         }
424
425         if ( ret != CS_CONNECTED )
426                 return ret;
427         
428         while( ret != CS_FINISHSEND ) {
429                 ret = TC_Send(cs);
430                 if ( ret == CS_ERROR ) return ret;
431         }
432
433         cs->state = CS_READ;
434         cs->ptr = cs->buf;
435         while( cs->state != CS_FINISHREAD ) {
436                 while( !TC_ReadyIO( &cs, 1, 100) );
437                 if ( ret == CS_ERROR ) return ret;
438                 if ( TC_Read(cs) == CS_ERROR ) return CS_ERROR;
439         }
440
441         return CS_OK; 
442 }
443
444