Fix searching key with presented its length
[tedtools.git] / tcp.c
1 /*
2  * Copyright (c) 2004 Teodor Sigaev <teodor@sigaev.ru>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *        notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *        notice, this list of conditions and the following disclaimer in the
12  *        documentation and/or other materials provided with the distribution.
13  * 3. Neither the name of the author nor the names of any co-contributors
14  *        may be used to endorse or promote products derived from this software
15  *        without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY CONTRIBUTORS ``AS IS'' AND ANY EXPRESS
18  * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED. IN NO EVENT SHALL CONTRIBUTORS BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
23  * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
25  * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
26  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
27  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28  */
29 #include <stdio.h>
30 #include <errno.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34 #include <fcntl.h>
35
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #else /* HAVE_POLL */
39 #ifdef HAVE_SYS_POLL_H
40 #include <sys/poll.h>
41 #else
42 #error Not defined HAVE_POLL_H or HAVE_SYS_POLL_H
43 #endif /* HAVE_SYS_POLL_H */
44 #endif /* HAVE_POLL */
45
46 #ifdef HAVE_HSTRERROR 
47 #include <netdb.h>
48 #endif
49
50
51 #include "connection.h"
52 #include "tlog.h"
53 #include "tmalloc.h"
54
55 static u_int32_t
56 setlinger( TC_Connection *cs ) {
57         struct linger ling;
58         int     val = 0;
59         socklen_t size = sizeof(val); 
60
61         if (getsockopt(cs->fd, SOL_SOCKET,SO_ERROR,&val,&size) == -1) {
62                 tlog(TL_ALARM,"getsockopt: %s:%d - %s(%d)",inet_ntoa(cs->serv_addr.sin_addr),
63                         ntohs(cs->serv_addr.sin_port), strerror(errno), errno);
64                 shutdown(cs->fd,SHUT_RDWR);
65                 close(cs->fd);
66                 cs->fd = 0;
67                 cs->state = CS_ERROR;
68                 return CS_ERROR;
69         }
70
71         if ( val ) {
72                 tlog(TL_ALARM,"getsockopt return: %s:%d - %s(%d)",inet_ntoa(cs->serv_addr.sin_addr),
73                         ntohs(cs->serv_addr.sin_port), strerror(val), val);
74                 shutdown(cs->fd,SHUT_RDWR);
75                 close(cs->fd);
76                 cs->fd = 0;
77                 cs->state = CS_ERROR;
78                 return CS_ERROR;
79         }
80
81
82         ling.l_onoff = ling.l_linger = 0;
83         if (setsockopt(cs->fd, SOL_SOCKET,SO_LINGER,(char *)&ling,sizeof(ling))==-1) {
84                 tlog(TL_ALARM,"setsockopt: LINGER %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port),
85                         strerror(errno));
86                 shutdown(cs->fd,SHUT_RDWR);
87                 close(cs->fd);
88                 cs->fd = 0;
89                 cs->state = CS_ERROR;
90                 return CS_ERROR;
91         }
92         cs->state = CS_CONNECTED;
93         return CS_CONNECTED;
94 }
95
96 u_int32_t
97 TC_ClientInitConnection(TC_Connection *cs, char *name, u_int32_t port) {
98         int flags, val=1;
99
100         cs = TC_fillConnection(cs, name, port);
101
102         cs->state = CS_OK;
103         if ((cs->fd= socket(AF_INET, SOCK_STREAM, 0)) < 0)
104                 tlog(TL_CRIT|TL_EXIT,"socket4: %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
105                         ntohs(cs->serv_addr.sin_port),strerror(errno));
106
107         if (setsockopt(cs->fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) < 0) {
108                 tlog(TL_CRIT|TL_EXIT, "socketsockopt failed: %s (%d)", strerror(errno), errno);
109                 return CS_ERROR;
110         }
111
112         if ((flags=fcntl(cs->fd,F_GETFL,0)) == -1)
113                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
114         if (fcntl(cs->fd,F_SETFL,flags|O_NDELAY) < 0 )
115                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
116
117         if (bind(cs->fd, (struct sockaddr *) &(cs->serv_addr), sizeof(cs->serv_addr)) < 0)
118                 tlog(TL_CRIT|TL_EXIT, "cannot bind to %s address: %s",
119                         inet_ntoa(cs->serv_addr.sin_addr), strerror(errno));
120         
121         if (listen(cs->fd, 0) < 0)
122                 tlog(TL_CRIT|TL_EXIT, "cannot listen to %s address: %s",
123                         inet_ntoa(cs->serv_addr.sin_addr), strerror(errno));
124         
125         return CS_OK;
126 }
127
128 TC_Connection*
129 TC_AcceptTcp(TC_Connection *cs) {
130         TC_Connection *nc;
131         struct sockaddr_in cli_addr;
132         int ret, flags;
133         socklen_t clilen = sizeof(cli_addr);
134
135         cs->state = CS_READ;
136         if ( (ret = accept(cs->fd,(struct sockaddr *)&cli_addr, &clilen)) < 0 ) {
137                 if ( errno == EAGAIN || errno == EWOULDBLOCK )
138                         return NULL;
139                 tlog(TL_ALARM,"TC_AcceptTcp: accept: %s", strerror(errno));
140                 return NULL;
141         }
142         nc = (TC_Connection*)t0malloc(sizeof(TC_Connection));
143
144         nc->fd = ret;
145         if ((flags=fcntl(nc->fd,F_GETFL,0)) == -1)
146                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
147         if (fcntl(nc->fd,F_SETFL,flags|O_NDELAY) < 0 )
148                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
149         memcpy( &(nc->serv_addr), &cli_addr, clilen );
150         nc->state = CS_CONNECTED;
151
152         setlinger(nc);
153         return nc;
154 }
155
156 TC_Connection *
157 TC_fillConnection(TC_Connection *sc, char *name, u_int32_t port) {
158         if ( !sc ) 
159                 sc = (TC_Connection *)tmalloc(sizeof(TC_Connection));
160         memset(sc, 0, sizeof(TC_Connection));
161         sc->serv_addr.sin_family = AF_INET;
162         sc->serv_addr.sin_addr.s_addr = (name && *name != '*' ) ? inet_addr(name) : htonl(INADDR_ANY);
163         if ( sc->serv_addr.sin_addr.s_addr == INADDR_NONE ) {
164                 struct hostent *host;
165
166                 /*
167                  * Can't parse address: it's a DNS Name
168                  */
169                 host = gethostbyname(name);
170                 if ( host && host->h_addrtype == AF_INET ) {
171                         memcpy(&sc->serv_addr.sin_addr.s_addr, host->h_addr_list[0], 
172                                 sizeof(&sc->serv_addr.sin_addr.s_addr));
173                 } else {
174                         tlog(TL_CRIT,"gethostbyname: %s - %s", name, hstrerror(h_errno));
175                         sc->state = CS_ERROR;
176                         return sc;
177                 }
178         }
179         
180         sc->serv_addr.sin_port = htons(port);
181         sc->state = CS_NOTINITED;
182         return sc; 
183 }
184
185
186 u_int32_t
187 TC_ServerInitConnect( TC_Connection     *cs ) {
188         int flags;
189
190         if ( cs->state == CS_ERROR )
191                 return CS_ERROR;
192
193         if ((cs->fd= socket(AF_INET, SOCK_STREAM, 0)) < 0) {
194                 tlog(TL_CRIT,"socket4: %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
195                         ntohs(cs->serv_addr.sin_port),strerror(errno));
196                 cs->state  = CS_ERROR;
197                 return  CS_ERROR;
198         }
199
200         if ((flags=fcntl(cs->fd,F_GETFL,0)) == -1)
201                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
202         if (fcntl(cs->fd,F_SETFL,flags|O_NDELAY) < 0 )
203                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
204
205         flags=1;
206         if (setsockopt(cs->fd, SOL_SOCKET, SO_REUSEADDR, &flags, sizeof(flags)) < 0)
207                 tlog(TL_ALARM, "socketsockopt failed: %s (%d)", strerror(errno), errno);
208
209         if ( connect(cs->fd, (struct sockaddr *) &(cs->serv_addr),
210                 sizeof(struct sockaddr_in)) < 0 ) {
211                 if ( errno == EINPROGRESS || errno == EALREADY ) {
212                         cs->state = CS_INPROCESS;
213                         return CS_INPROCESS; 
214                 } else if (errno != EISCONN && errno != EALREADY &&
215                         errno != EWOULDBLOCK && errno != EAGAIN) {
216                         tlog(TL_DEBUG,"connect: %s:%d - %s",
217                                 inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port),
218                                 strerror(errno));
219                         shutdown(cs->fd,SHUT_RDWR);
220                         close(cs->fd);
221                         cs->fd = 0;
222                 } else {
223                         tlog(TL_DEBUG,"nonblock connect: %s:%d - %s [%d]",
224                                 inet_ntoa(cs->serv_addr.sin_addr),
225                                 ntohs(cs->serv_addr.sin_port),
226                                 strerror(errno),errno);
227                 }
228                 cs->state = CS_ERROR;
229                 return CS_ERROR;
230         }
231
232         cs->state = CS_INPROCESS;
233         return TC_ServerConnect( cs, 0 );
234 }
235         
236
237 u_int32_t
238 TC_ServerConnect( TC_Connection *cs, int timeout ) {
239         struct pollfd   pfd;
240         int ret;
241
242         if ( cs->state != CS_INPROCESS )
243                 return cs->state;
244
245         pfd.fd = cs->fd;
246         pfd.events = POLLOUT;
247         pfd.revents = 0;
248         ret = poll( &pfd, 1, timeout );
249         if ( ret<0 ) {
250                 tlog( TL_CRIT, "TC_ServerConnect: poll: %s",
251                         strerror(errno));
252                 cs->state = CS_ERROR;
253                 return CS_ERROR;
254         } else if ( ret == 0 ) 
255                 return CS_INPROCESS;
256
257         if ( (pfd.revents & (POLLHUP | POLLNVAL | POLLERR)) ) {
258                 tlog( TL_CRIT, "TC_ServerConnect: poll return connect error for %s:%d",
259                         inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port));
260                 cs->state = CS_ERROR;
261                 return CS_ERROR;
262         }
263
264         if ( ! (pfd.revents & POLLOUT) )
265                 return CS_INPROCESS;
266
267
268         return setlinger( cs );
269 }
270
271 int
272 TC_ReadyIO( TC_Connection **cs, int number, int timeout ) {
273         struct pollfd   *pfd;
274         int ret,i, fdnum=0;
275
276         if ( number==0 || cs ==NULL ) {
277                 if (timeout<0)
278                         timeout=1000;
279                 usleep( timeout * 1000.0 );
280                 return 0;
281         }
282         pfd = (struct pollfd*) tmalloc( sizeof(struct pollfd) * number );
283
284         for(i=0; i<number;i++) {
285                 if ( cs[i]->fd>0 && (cs[i]->state == CS_READ || cs[i]->state == CS_SEND) ) {
286                         pfd[fdnum].fd = cs[i]->fd;
287                         pfd[fdnum].events = ( cs[i]->state == CS_READ ) ? POLLIN : POLLOUT;
288                         pfd[fdnum].revents = 0;
289                         fdnum++;
290                 }
291                 cs[i]->readyio=0;
292         }
293
294         if ( fdnum==0 ) {
295                 tfree(pfd);
296                 if (timeout<0)
297                         timeout=1000;
298                 usleep( timeout * 1000.0 );
299                 return 0;
300         }       
301         ret = poll( pfd, fdnum, timeout );
302         if ( ret<0 ) {
303                 tlog( TL_CRIT, "TC_ReadyIO: poll: %s",
304                         strerror(errno));
305                 tfree(pfd);
306                 return 0;
307         }
308
309         if ( ret == 0 ) {
310                 tfree(pfd);
311                 return 0;
312         }
313
314         fdnum=0; ret=0;
315         for(i=0; i<number;i++) {
316                 if ( cs[i]->fd>0 && (cs[i]->state == CS_READ || cs[i]->state == CS_SEND) ) {
317                         if ( pfd[fdnum].revents & (POLLHUP | POLLNVAL | POLLERR) ) { 
318                                 tlog( TL_ALARM, "TC_ReadyIO: poll return error for %s:%d",
319                                         inet_ntoa(cs[i]->serv_addr.sin_addr), 
320                                         ntohs(cs[i]->serv_addr.sin_port));
321                                 cs[i]->state = CS_ERROR;
322                                 ret = 1;
323                         } else if ( pfd[fdnum].revents & ( ( cs[i]->state == CS_READ ) ? POLLIN : POLLOUT ) ) {
324                                 cs[i]->readyio=1;
325                                 ret = 1;
326                         }
327                         fdnum++;
328                 }
329         }
330
331         tfree(pfd);
332         return ret;
333 }
334
335 u_int32_t
336 TC_Send( TC_Connection *cs ) {
337         int sz;
338         
339         if ( cs->state == CS_ERROR )
340                 return CS_ERROR;
341
342         if ( cs->state != CS_SEND || cs->ptr == NULL ) {
343                 cs->state = CS_SEND;
344                 cs->ptr = (char*)cs->buf;
345                 cs->len = cs->buf->len;
346
347                 /* convert fields to network byteorder */
348                 cs->buf->len = htonl(cs->buf->len);
349                 cs->buf->type = htonl(cs->buf->type);
350         }
351
352         if ( cs->ptr - (char*)cs->buf >= cs->len ) {
353                 cs->state = CS_FINISHSEND;
354                 return CS_FINISHSEND;
355         }
356
357         if ((sz=write(cs->fd, cs->ptr, cs->len - (cs->ptr - (char*)cs->buf)))==0 ||
358                 (sz < 0 && (errno == EWOULDBLOCK || errno == EAGAIN))) {
359
360                 /* SunOS 4.1.x, are broken and select() says that
361                  * O_NDELAY sockets are always writable even when
362                  * they're actually not.
363                  */
364                 cs->state = CS_SEND;
365                 return CS_SEND;
366         }
367         if ( sz<0 ) {
368                 if (errno != EPIPE && errno != EINVAL)
369                         tlog(TL_ALARM, "write[%s:%d] - %s",
370                                 inet_ntoa(cs->serv_addr.sin_addr),
371                                 ntohs(cs->serv_addr.sin_port), 
372                                 strerror(errno));
373                 cs->state = CS_ERROR;
374                 return CS_ERROR;
375         }
376
377         cs->ptr += sz;
378
379         if ( cs->ptr - (char*)cs->buf >= cs->len ) {
380                 cs->state = CS_FINISHSEND;
381                 /* revert byteorder conversion */
382                 cs->buf->len = ntohl(cs->buf->len);
383                 cs->buf->type = ntohl(cs->buf->type);
384                 return CS_FINISHSEND;
385         }
386         
387         return cs->state;
388 }
389
390 static void 
391 resizeCS( TC_Connection *cs, int sz ) {
392         int diff = cs->ptr - (char*)cs->buf;
393
394         if ( cs->len >= sz )
395                 return; 
396         cs->len = sz;
397         cs->buf = (TCMsg*)trealloc( (void*)cs->buf, cs->len );
398         cs->ptr = ((char*)cs->buf) + diff;
399 }
400
401 u_int32_t
402 TC_Read( TC_Connection *cs, size_t maxsize ) {
403         int sz, totalread = -1, toread=0, alreadyread;
404
405         if ( cs->state == CS_ERROR )
406                 return CS_ERROR;
407
408         if (cs->state != CS_READ || cs->ptr == NULL ) {
409                 cs->state = CS_READ;
410                 cs->ptr = (char*)cs->buf;
411                 cs->len = 0;
412         }
413
414         alreadyread = cs->ptr - (char*)cs->buf;
415         if ( alreadyread < TCMSGHDRSZ ) {
416                 toread = TCMSGHDRSZ - alreadyread;
417                 resizeCS(cs, TCMSGHDRSZ);
418         } else {
419                 totalread = ntohl(cs->buf->len);
420                 if ( maxsize > 0 && totalread > maxsize )
421                 {
422                         tlog(TL_ALARM,"TC_Read: message size (%d b) is greater than max allowed (%d b)", totalread, maxsize);
423                         cs->state = CS_ERROR;
424                         return CS_ERROR;
425                 }
426                 toread = totalread - alreadyread;
427                 if ( toread == 0 ) {
428                         cs->state = CS_FINISHREAD;
429                         return CS_FINISHREAD;
430                 }
431                 resizeCS(cs, totalread);
432         }
433
434         if ((sz=read( cs->fd, cs->ptr, toread))<0) {
435                 if (errno == EAGAIN || errno == EINTR) {
436                         cs->state = CS_READ;
437                         return CS_READ;
438                 }
439                 tlog(TL_ALARM,"read: finish - %s",strerror(errno));
440                 cs->state = CS_ERROR;
441                 return CS_ERROR;
442         }
443         
444         if ( alreadyread < TCMSGHDRSZ  && alreadyread + sz >= TCMSGHDRSZ ) {
445                 /* 
446                  * we just read header - we can get totalread value. 
447                  */
448                 totalread = ntohl(cs->buf->len);
449         }
450         
451         cs->ptr += sz;
452         alreadyread += sz;
453         if ( sz == 0 && alreadyread != totalread ) {
454                 tlog(TL_ALARM,"read: disconnecting");
455                 cs->state = CS_ERROR;
456                 return CS_ERROR;
457         }
458
459         if ( alreadyread == totalread ) {
460                 cs->buf->len = ntohl(cs->buf->len);
461                 cs->buf->type = ntohl(cs->buf->type);
462                 cs->state = CS_FINISHREAD;
463         }
464
465         return cs->state;
466 }
467
468 void
469 TC_FreeConnection( TC_Connection *cs ) {
470         if ( cs->state == CS_CLOSED )
471                 return;
472         if ( cs->buf ) {
473                 tfree(cs->buf);
474                 cs->buf = NULL;
475         }
476         if ( cs->fd && cs->state != CS_NOTINITED ) {
477                 shutdown(cs->fd,SHUT_RDWR);
478                 close(cs->fd);
479         }
480         cs->fd = 0;
481         cs->state = CS_CLOSED;
482 }
483
484 u_int32_t 
485 TC_Talk( TC_Connection *cs, size_t maxsize  ) {
486         if ( cs->state==CS_NOTINITED ) 
487                 TC_ServerInitConnect( cs );
488
489         while( cs->state == CS_INPROCESS ) 
490                 TC_ServerConnect(cs, 100);
491
492         if ( cs->state != CS_CONNECTED )
493                 return cs->state;
494         
495         cs->state = CS_SEND;
496         cs->ptr = NULL;
497         while( cs->state != CS_FINISHSEND ) {
498                 while( !TC_ReadyIO( &cs, 1, 100) );
499                 if ( TC_Send(cs) == CS_ERROR ) return CS_ERROR;
500         }
501
502         cs->state = CS_READ;
503         cs->ptr = NULL;
504         while( cs->state != CS_FINISHREAD ) {
505                 while( !TC_ReadyIO( &cs, 1, 100) );
506                 if ( TC_Read(cs, maxsize) == CS_ERROR ) return CS_ERROR;
507         }
508
509         return CS_OK; 
510 }
511
512