Fix compiler warnings
[smlar.git] / smlar.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/genam.h"
5 #include "access/heapam.h"
6 #include "access/htup_details.h"
7 #include "access/nbtree.h"
8 #include "catalog/indexing.h"
9 #include "catalog/pg_am.h"
10 #include "catalog/pg_amproc.h"
11 #include "catalog/pg_cast.h"
12 #include "catalog/pg_opclass.h"
13 #include "catalog/pg_type.h"
14 #include "executor/spi.h"
15 #include "utils/catcache.h"
16 #include "utils/fmgroids.h"
17 #include "utils/lsyscache.h"
18 #include "utils/memutils.h"
19 #if (PG_VERSION_NUM < 120000)
20 #include "utils/tqual.h"
21 #endif
22 #include "utils/syscache.h"
23 #include "utils/typcache.h"
24
25 PG_MODULE_MAGIC;
26
27 #if (PG_VERSION_NUM >= 90400)
28 #define SNAPSHOT NULL
29 #else
30 #define SNAPSHOT SnapshotNow
31 #endif
32
33 #if PG_VERSION_NUM >= 130000
34 #define heap_open(r, l)                 table_open((r), (l))
35 #define heap_close(r, l)                table_close((r), (l))
36 #endif
37
38 static Oid
39 getDefaultOpclass(Oid amoid, Oid typid)
40 {
41         ScanKeyData     skey;
42         SysScanDesc     scan;
43         HeapTuple       tuple;
44         Relation        heapRel;
45         Oid                     opclassOid = InvalidOid;
46
47         heapRel = heap_open(OperatorClassRelationId, AccessShareLock);
48
49         ScanKeyInit(&skey,
50                                 Anum_pg_opclass_opcmethod,
51                                 BTEqualStrategyNumber,  F_OIDEQ,
52                                 ObjectIdGetDatum(amoid));
53
54         scan = systable_beginscan(heapRel,
55                                                                 OpclassAmNameNspIndexId, true,
56                                                                 SNAPSHOT, 1, &skey);
57
58         while (HeapTupleIsValid((tuple = systable_getnext(scan))))
59         {
60                 Form_pg_opclass opclass = (Form_pg_opclass)GETSTRUCT(tuple);
61
62                 if ( opclass->opcintype == typid && opclass->opcdefault )
63                 {
64                         if ( OidIsValid(opclassOid) )
65                                 elog(ERROR, "Ambiguous opclass for type %u (access method %u)", typid, amoid); 
66 #if (PG_VERSION_NUM >= 120000)
67                         opclassOid = opclass->oid;
68 #else
69                         opclassOid = HeapTupleGetOid(tuple);
70 #endif
71                 }
72         }
73
74         systable_endscan(scan);
75         heap_close(heapRel, AccessShareLock);
76
77         return opclassOid;
78 }
79
80 static Oid
81 getAMProc(Oid amoid, Oid typid)
82 {
83         Oid             opclassOid = getDefaultOpclass(amoid, typid);
84         Oid             procOid = InvalidOid;
85         Oid             opfamilyOid;
86         ScanKeyData     skey[4];
87         SysScanDesc     scan;
88         HeapTuple       tuple;
89         Relation        heapRel;
90
91         if ( !OidIsValid(opclassOid) )
92         {
93                 typid = getBaseType(typid);
94                 opclassOid = getDefaultOpclass(amoid, typid);
95         }
96
97         if ( !OidIsValid(opclassOid) )
98         {
99                 CatCList        *catlist;
100                 int                     i;
101
102                 /*
103                  * Search binary-coercible type
104                  */
105 #ifdef SearchSysCacheList1
106                 catlist = SearchSysCacheList1(CASTSOURCETARGET,
107                                                                           ObjectIdGetDatum(typid));
108 #else
109                 catlist = SearchSysCacheList(CASTSOURCETARGET, 1,
110                                                                                 ObjectIdGetDatum(typid),
111                                                                                 0, 0, 0);
112 #endif
113
114                 for (i = 0; i < catlist->n_members; i++)
115                 {
116                         HeapTuple               tuple = &catlist->members[i]->tuple;
117                         Form_pg_cast    castForm = (Form_pg_cast)GETSTRUCT(tuple);
118
119                         if ( castForm->castfunc == InvalidOid && castForm->castcontext == COERCION_CODE_IMPLICIT )
120                         {
121                                 typid = castForm->casttarget;
122                                 opclassOid = getDefaultOpclass(amoid, typid);
123                                 if( OidIsValid(opclassOid) )
124                                         break;
125                         }
126                 }
127
128                 ReleaseSysCacheList(catlist);
129         }
130
131         if ( !OidIsValid(opclassOid) )
132                 return InvalidOid;
133
134         opfamilyOid = get_opclass_family(opclassOid);
135
136         heapRel = heap_open(AccessMethodProcedureRelationId, AccessShareLock);
137         ScanKeyInit(&skey[0],
138                                 Anum_pg_amproc_amprocfamily,
139                                 BTEqualStrategyNumber, F_OIDEQ,
140                                 ObjectIdGetDatum(opfamilyOid));
141         ScanKeyInit(&skey[1],
142                                 Anum_pg_amproc_amproclefttype,
143                                 BTEqualStrategyNumber, F_OIDEQ,
144                                 ObjectIdGetDatum(typid));
145         ScanKeyInit(&skey[2],
146                                 Anum_pg_amproc_amprocrighttype,
147                                 BTEqualStrategyNumber, F_OIDEQ,
148                                 ObjectIdGetDatum(typid));
149 #if PG_VERSION_NUM >= 90200
150         ScanKeyInit(&skey[3],
151                                 Anum_pg_amproc_amprocnum,
152                                 BTEqualStrategyNumber, F_OIDEQ,
153                                 Int32GetDatum(BTORDER_PROC));
154 #endif
155
156         scan = systable_beginscan(heapRel, AccessMethodProcedureIndexId, true,
157                                                                 SNAPSHOT,
158 #if PG_VERSION_NUM >= 90200
159                                                                 4,
160 #else
161                                                                 3,
162 #endif
163                                                                 skey);
164         while (HeapTupleIsValid(tuple = systable_getnext(scan)))
165         {
166                 Form_pg_amproc amprocform = (Form_pg_amproc) GETSTRUCT(tuple);
167
168                 switch(amoid)
169                 {
170                         case BTREE_AM_OID:
171                         case HASH_AM_OID:
172                                 if ( OidIsValid(procOid) )
173                                         elog(ERROR,"Ambiguous support function for type %u (opclass %u)", typid, opfamilyOid);
174                                 procOid = amprocform->amproc;
175                                 break;
176                         default:
177                                 elog(ERROR,"Unsupported access method");
178                 }
179         }
180
181         systable_endscan(scan);
182         heap_close(heapRel, AccessShareLock);
183
184         return procOid;
185 }
186
187 static ProcTypeInfo *cacheProcs = NULL;
188 static int nCacheProcs = 0;
189
190 #ifndef TupleDescAttr
191 #define TupleDescAttr(tupdesc, i)       ((tupdesc)->attrs[(i)])
192 #endif
193
194 static ProcTypeInfo
195 fillProcs(Oid typid)
196 {
197         ProcTypeInfo    info = malloc(sizeof(ProcTypeInfoData));
198
199         if (!info)
200                 elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfoData));
201
202         info->typid = typid;
203         info->typtype = get_typtype(typid);
204
205         if (info->typtype == 'c')
206         {
207                 /* composite type */
208                 TupleDesc               tupdesc;
209                 MemoryContext   oldcontext;
210
211                 tupdesc = lookup_rowtype_tupdesc(typid, -1);
212
213                 if (tupdesc->natts != 2)
214                         elog(ERROR,"Composite type has wrong number of fields");
215                 if (TupleDescAttr(tupdesc, 1)->atttypid != FLOAT4OID)
216                         elog(ERROR,"Second field of composite type is not float4");
217
218                 oldcontext = MemoryContextSwitchTo(TopMemoryContext);
219                 info->tupDesc = CreateTupleDescCopyConstr(tupdesc);
220                 MemoryContextSwitchTo(oldcontext);
221
222                 ReleaseTupleDesc(tupdesc);
223
224                 info->cmpFuncOid = getAMProc(BTREE_AM_OID,
225                                                                          TupleDescAttr(info->tupDesc, 0)->atttypid);
226                 info->hashFuncOid = getAMProc(HASH_AM_OID,
227                                                                           TupleDescAttr(info->tupDesc, 0)->atttypid);
228         }
229         else
230         {
231                 info->tupDesc = NULL;
232
233                 /* plain type */
234                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, typid);
235                 info->hashFuncOid = getAMProc(HASH_AM_OID, typid);
236         }
237
238         get_typlenbyvalalign(typid, &info->typlen, &info->typbyval, &info->typalign);
239         info->hashFuncInited = info->cmpFuncInited = false;
240
241
242         return info;
243 }
244
245 void
246 getFmgrInfoCmp(ProcTypeInfo info)
247 {
248         if ( info->cmpFuncInited == false )
249         {
250                 if ( !OidIsValid(info->cmpFuncOid) )
251                         elog(ERROR, "Could not find cmp function for type %u", info->typid);
252
253                 fmgr_info_cxt( info->cmpFuncOid, &info->cmpFunc, TopMemoryContext );
254                 info->cmpFuncInited = true;
255         }
256 }
257
258 void
259 getFmgrInfoHash(ProcTypeInfo info)
260 {
261         if ( info->hashFuncInited == false )
262         {
263                 if ( !OidIsValid(info->hashFuncOid) )
264                         elog(ERROR, "Could not find hash function for type %u", info->typid);
265
266                 fmgr_info_cxt( info->hashFuncOid, &info->hashFunc, TopMemoryContext );
267                 info->hashFuncInited = true;
268         }
269 }
270
271 static int
272 cmpProcTypeInfo(const void *a, const void *b)
273 {
274         ProcTypeInfo av = *(ProcTypeInfo*)a;
275         ProcTypeInfo bv = *(ProcTypeInfo*)b;
276
277         Assert( av->typid != bv->typid );
278
279         return ( av->typid > bv->typid ) ? 1 : -1;
280 }
281
282 ProcTypeInfo
283 findProcs(Oid typid)
284 {
285         ProcTypeInfo    info = NULL;
286
287         if ( nCacheProcs == 1 )
288         {
289                 if ( cacheProcs[0]->typid == typid )
290                 {
291                         /*cacheProcs[0]->hashFuncInited = cacheProcs[0]->cmpFuncInited = false;*/
292                         return cacheProcs[0];
293                 }
294         }
295         else if ( nCacheProcs > 1 )
296         {
297                 ProcTypeInfo    *StopMiddle;
298                 ProcTypeInfo    *StopLow = cacheProcs,
299                                                 *StopHigh = cacheProcs + nCacheProcs;
300
301                 while (StopLow < StopHigh) {
302                         StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
303                         info = *StopMiddle;
304
305                         if ( info->typid == typid )
306                         {
307                                 /* info->hashFuncInited = info->cmpFuncInited = false; */
308                                 return info;
309                         }
310                         else if ( info->typid < typid )
311                                 StopLow = StopMiddle + 1;
312                         else
313                                 StopHigh = StopMiddle;
314                 }
315
316                 /* not found */
317         } 
318
319         info = fillProcs(typid);
320         if ( nCacheProcs == 0 )
321         {
322                 cacheProcs = malloc(sizeof(ProcTypeInfo));
323
324                 if (!cacheProcs)
325                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo));
326                 else
327                 {
328                         nCacheProcs = 1;
329                         cacheProcs[0] = info;
330                 }
331         }
332         else
333         {
334                 ProcTypeInfo    *cacheProcsTmp = realloc(cacheProcs, (nCacheProcs+1) * sizeof(ProcTypeInfo));
335
336                 if (!cacheProcsTmp)
337                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo) * (nCacheProcs+1));
338                 else
339                 {
340                         cacheProcs = cacheProcsTmp;
341                         cacheProcs[ nCacheProcs ] = info;
342                         nCacheProcs++;
343                         qsort(cacheProcs, nCacheProcs, sizeof(ProcTypeInfo), cmpProcTypeInfo);
344                 }
345         }
346
347         /* info->hashFuncInited = info->cmpFuncInited = false; */
348
349         return info;
350 }
351
352 /*
353  * WARNING. Array2SimpleArray* doesn't copy Datum!
354  */
355 SimpleArray * 
356 Array2SimpleArray(ProcTypeInfo info, ArrayType *a)
357 {
358         SimpleArray     *s = palloc(sizeof(SimpleArray));
359
360         CHECKARRVALID(a);
361
362         if ( info == NULL )
363                 info = findProcs(ARR_ELEMTYPE(a));
364
365         s->info = info;
366         s->df = NULL;
367         s->hash = NULL;
368
369         deconstruct_array(a, info->typid,
370                                                 info->typlen, info->typbyval, info->typalign,
371                                                 &s->elems, NULL, &s->nelems);
372
373         return s;
374 }
375
376 static Datum
377 deconstructCompositeType(ProcTypeInfo info, Datum in, double *weight)
378 {
379         HeapTupleHeader rec = DatumGetHeapTupleHeader(in);
380         HeapTupleData   tuple;
381         Datum                   values[2];
382         bool                    nulls[2];
383
384         /* Build a temporary HeapTuple control structure */
385         tuple.t_len = HeapTupleHeaderGetDatumLength(rec);
386         ItemPointerSetInvalid(&(tuple.t_self));
387         tuple.t_tableOid = InvalidOid;
388         tuple.t_data = rec;
389
390         heap_deform_tuple(&tuple, info->tupDesc, values, nulls);
391         if (nulls[0] || nulls[1])
392                 elog(ERROR, "Both fields in composite type could not be NULL");
393
394         if (weight)
395                 *weight = DatumGetFloat4(values[1]);
396         return values[0];
397 }
398
399 static int
400 cmpArrayElem(const void *a, const void *b, void *arg)
401 {
402         ProcTypeInfo    info = (ProcTypeInfo)arg;
403
404         if (info->tupDesc)
405                 /* composite type */
406                 return DatumGetInt32( FCall2( &info->cmpFunc,
407                                                 deconstructCompositeType(info, *(Datum*)a, NULL),
408                                                 deconstructCompositeType(info, *(Datum*)b, NULL) ) );
409
410         return DatumGetInt32( FCall2( &info->cmpFunc,
411                                                         *(Datum*)a, *(Datum*)b ) );
412 }
413
414 SimpleArray *
415 Array2SimpleArrayS(ProcTypeInfo info, ArrayType *a)
416 {
417         SimpleArray     *s = Array2SimpleArray(info, a);
418
419         if ( s->nelems > 1 )
420         {
421                 getFmgrInfoCmp(s->info);
422
423                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElem, s->info);
424         }
425
426         return s;
427 }
428
429 typedef struct cmpArrayElemData {
430         ProcTypeInfo    info;
431         bool                    hasDuplicate;
432
433 } cmpArrayElemData;
434
435 static int
436 cmpArrayElemArg(const void *a, const void *b, void *arg)
437 {
438         cmpArrayElemData        *data = (cmpArrayElemData*)arg;
439         int                                     res;
440
441         if (data->info->tupDesc)
442                 res =  DatumGetInt32( FCall2( &data->info->cmpFunc,
443                                         deconstructCompositeType(data->info, *(Datum*)a, NULL),
444                                         deconstructCompositeType(data->info, *(Datum*)b, NULL) ) );
445         else
446                 res = DatumGetInt32( FCall2( &data->info->cmpFunc,
447                                                                 *(Datum*)a, *(Datum*)b ) );
448
449         if ( res == 0 )
450                 data->hasDuplicate = true;
451
452         return res;
453 }
454
455 /*
456  * Uniquefy array and calculate TF. Although 
457  * result doesn't depend on normalization, we
458  * normalize TF by length array to have possiblity
459  * to limit estimation for index support.
460  *
461  * Cache signals of needing of TF caclulation
462  */
463
464 SimpleArray *
465 Array2SimpleArrayU(ProcTypeInfo info, ArrayType *a, void *cache)
466 {
467         SimpleArray     *s = Array2SimpleArray(info, a);
468         StatElem        *stat = NULL;
469
470         if ( s->nelems > 0 && cache )
471         {
472                 s->df = palloc(sizeof(double) * s->nelems);
473                 s->df[0] = 1.0; /* init */
474         }
475
476         if ( s->nelems > 1 )
477         {
478                 cmpArrayElemData        data;
479                 int                                     i;
480
481                 getFmgrInfoCmp(s->info);
482                 data.info = s->info;
483                 data.hasDuplicate = false;
484
485                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElemArg, &data);
486
487                 if ( data.hasDuplicate )
488                 {
489                         Datum   *tmp,
490                                         *dr,
491                                         *data;
492                         int             num = s->nelems,
493                                         cmp;
494
495                         data = tmp = dr = s->elems;
496
497                         while (tmp - data < num)
498                         {
499                                 cmp = (tmp == dr) ? 0 : cmpArrayElem(tmp, dr, s->info);
500                                 if ( cmp != 0 )
501                                 {
502                                         *(++dr) = *tmp++;
503                                         if ( cache ) 
504                                                 s->df[ dr - data ] = 1.0;
505                                 }
506                                 else
507                                 {
508                                         if ( cache )
509                                                 s->df[ dr - data ] += 1.0;
510                                         tmp++;
511                                 }
512                         }
513
514                         s->nelems = dr + 1 - s->elems;
515
516                         if ( cache )
517                         {
518                                 int tfm = getTFMethod();
519
520                                 for(i=0;i<s->nelems;i++)
521                                 {
522                                         stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
523                                         if ( stat )
524                                         {
525                                                 switch(tfm)
526                                                 {
527                                                         case TF_LOG:
528                                                                 s->df[i] = (1.0 + log( s->df[i] ));
529                                                                 /* FALLTHROUGH */
530                                                         case TF_N:
531                                                                 s->df[i] *= stat->idf;
532                                                                 break;
533                                                         case TF_CONST:
534                                                                 s->df[i] = stat->idf;
535                                                                 break;
536                                                         default:
537                                                                 elog(ERROR,"Unknown TF method: %d", tfm);
538                                                 }
539                                         }
540                                         else
541                                         {
542                                                 s->df[i] = 0.0; /* unknown word */
543                                         }
544                                 }
545                         }
546                 }
547                 else if ( cache )
548                 {
549                         for(i=0;i<s->nelems;i++)
550                         {
551                                 stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
552                                 if ( stat )
553                                         s->df[i] = stat->idf;
554                                 else
555                                         s->df[i] = 0.0;
556                         }
557                 }
558         }
559         else if (s->nelems > 0 && cache)
560         {
561                 stat = fingArrayStat(cache, s->info->typid, s->elems[0], stat);
562                 if ( stat )
563                         s->df[0] = stat->idf;
564                 else
565                         s->df[0] = 0.0;
566         }
567
568         return s;
569 }
570
571 static int
572 numOfIntersect(SimpleArray *a, SimpleArray *b)
573 {
574         int                             cnt = 0,
575                                         cmp;
576         Datum                   *aptr = a->elems,
577                                         *bptr = b->elems;
578         ProcTypeInfo    info = a->info;
579
580         Assert( a->info->typid == b->info->typid );
581
582         getFmgrInfoCmp(info);
583
584         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
585         {
586                 cmp = cmpArrayElem(aptr, bptr, info);
587                 if ( cmp < 0 )
588                         aptr++;
589                 else if ( cmp > 0 )
590                         bptr++;
591                 else
592                 {
593                         cnt++;
594                         aptr++;
595                         bptr++;
596                 }
597         }
598
599         return cnt;
600 }
601
602 static double
603 TFIDFSml(SimpleArray *a, SimpleArray *b)
604 {
605         int                             cmp;
606         Datum                   *aptr = a->elems,
607                                         *bptr = b->elems;
608         ProcTypeInfo    info = a->info;
609         double                  res = 0.0;
610         double                  suma = 0.0, sumb = 0.0;
611
612         Assert( a->info->typid == b->info->typid );
613         Assert( a->df );
614         Assert( b->df );
615
616         getFmgrInfoCmp(info);
617
618         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
619         {
620                 cmp = cmpArrayElem(aptr, bptr, info);
621                 if ( cmp < 0 )
622                 {
623                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
624                         aptr++;
625                 }
626                 else if ( cmp > 0 )
627                 {
628                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
629                         bptr++;
630                 }
631                 else
632                 {
633                         res += a->df[ aptr - a->elems ] * b->df[ bptr - b->elems ];
634                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
635                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
636                         aptr++;
637                         bptr++;
638                 }
639         }
640
641         /*
642          * Compute last elements
643          */
644         while( aptr - a->elems < a->nelems )
645         {
646                 suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
647                 aptr++;
648         }
649
650         while( bptr - b->elems < b->nelems )
651         {
652                 sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
653                 bptr++;
654         }
655
656         if ( suma > 0.0 && sumb > 0.0 )
657                 res = res / sqrt( suma * sumb );
658         else
659                 res = 0.0;
660
661         return res;
662 }
663
664
665 PG_FUNCTION_INFO_V1(arraysml);
666 Datum   arraysml(PG_FUNCTION_ARGS);
667 Datum
668 arraysml(PG_FUNCTION_ARGS)
669 {
670         ArrayType               *a, *b;
671         SimpleArray             *sa, *sb;
672
673         fcinfo->flinfo->fn_extra = SearchArrayCache(
674                                                         fcinfo->flinfo->fn_extra,
675                                                         fcinfo->flinfo->fn_mcxt,
676                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
677         fcinfo->flinfo->fn_extra = SearchArrayCache(
678                                                         fcinfo->flinfo->fn_extra,
679                                                         fcinfo->flinfo->fn_mcxt,
680                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
681
682         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
683                 elog(ERROR,"Arguments array are not the same type!");
684
685         if (ARRISVOID(a) || ARRISVOID(b))
686                  PG_RETURN_FLOAT4(0.0);
687
688         switch(getSmlType())
689         {
690                 case ST_TFIDF:
691                         PG_RETURN_FLOAT4( TFIDFSml(sa, sb) );
692                         break;
693                 case ST_COSINE:
694                         {
695                                 int                             cnt;
696                                 double                  power;
697
698                                 power = ((double)(sa->nelems)) * ((double)(sb->nelems));
699                                 cnt = numOfIntersect(sa, sb);
700
701                                 PG_RETURN_FLOAT4(  ((double)cnt) / sqrt( power ) );
702                         }
703                         break;
704                 case ST_OVERLAP:
705                         {
706                                 float4 res = (float4)numOfIntersect(sa, sb);
707
708                                 PG_RETURN_FLOAT4(res);
709                         }
710                         break;
711                 default:
712                         elog(ERROR,"Unsupported formula type of similarity");
713         }
714
715         PG_RETURN_FLOAT4(0.0); /* keep compiler quiet */
716 }
717
718 PG_FUNCTION_INFO_V1(arraysmlw);
719 Datum   arraysmlw(PG_FUNCTION_ARGS);
720 Datum
721 arraysmlw(PG_FUNCTION_ARGS)
722 {
723         ArrayType               *a, *b;
724         SimpleArray             *sa, *sb;
725         bool                    useIntersect = PG_GETARG_BOOL(2);
726         double                  numerator = 0.0;
727         double                  denominatorA = 0.0,
728                                         denominatorB = 0.0,
729                                         tmpA, tmpB;
730         int                             cmp;
731         ProcTypeInfo    info;
732         int                             ai = 0, bi = 0;
733
734         fcinfo->flinfo->fn_extra = SearchArrayCache(
735                                                         fcinfo->flinfo->fn_extra,
736                                                         fcinfo->flinfo->fn_mcxt,
737                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
738         fcinfo->flinfo->fn_extra = SearchArrayCache(
739                                                         fcinfo->flinfo->fn_extra,
740                                                         fcinfo->flinfo->fn_mcxt,
741                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
742
743         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
744                 elog(ERROR,"Arguments array are not the same type!");
745
746         if (ARRISVOID(a) || ARRISVOID(b))
747                  PG_RETURN_FLOAT4(0.0);
748
749         info = sa->info;
750         if (info->tupDesc == NULL)
751                 elog(ERROR, "Only weigthed (composite) types should be used");
752         getFmgrInfoCmp(info);
753
754         while(ai < sa->nelems && bi < sb->nelems)
755         {
756                 Datum   ad = deconstructCompositeType(info, sa->elems[ai], &tmpA),
757                                 bd = deconstructCompositeType(info, sb->elems[bi], &tmpB);
758
759                 cmp = DatumGetInt32(FCall2(&info->cmpFunc, ad, bd));
760
761                 if ( cmp < 0 ) {
762                         if (useIntersect == false)
763                                 denominatorA += tmpA * tmpA;
764                         ai++;
765                 } else if ( cmp > 0 ) {
766                         if (useIntersect == false)
767                                 denominatorB += tmpB * tmpB;
768                         bi++;
769                 } else {
770                         denominatorA += tmpA * tmpA;
771                         denominatorB += tmpB * tmpB;
772                         numerator += tmpA * tmpB;
773                         ai++;
774                         bi++;
775                 }
776         }
777
778         if (useIntersect == false) {
779                 while(ai < sa->nelems) {
780                         deconstructCompositeType(info, sa->elems[ai], &tmpA);
781                         denominatorA += tmpA * tmpA;
782                         ai++;
783                 }
784
785                 while(bi < sb->nelems) {
786                         deconstructCompositeType(info, sb->elems[bi], &tmpB);
787                         denominatorB += tmpB * tmpB;
788                         bi++;
789                 }
790         }
791
792         if (numerator != 0.0) {
793                 numerator = numerator / sqrt( denominatorA * denominatorB );
794         }
795
796         PG_RETURN_FLOAT4(numerator);
797 }
798
799 PG_FUNCTION_INFO_V1(arraysml_op);
800 Datum   arraysml_op(PG_FUNCTION_ARGS);
801 Datum
802 arraysml_op(PG_FUNCTION_ARGS)
803 {
804         ArrayType               *a, *b;
805         SimpleArray             *sa, *sb;
806         double                  power = 0.0;
807
808         fcinfo->flinfo->fn_extra = SearchArrayCache(
809                                                         fcinfo->flinfo->fn_extra,
810                                                         fcinfo->flinfo->fn_mcxt,
811                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
812         fcinfo->flinfo->fn_extra = SearchArrayCache(
813                                                         fcinfo->flinfo->fn_extra,
814                                                         fcinfo->flinfo->fn_mcxt,
815                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
816
817         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
818                 elog(ERROR,"Arguments array are not the same type!");
819
820         if (ARRISVOID(a) || ARRISVOID(b))
821                  PG_RETURN_BOOL(false);
822
823         switch(getSmlType())
824         {
825                 case ST_TFIDF:
826                         power = TFIDFSml(sa, sb);
827                         break;
828                 case ST_COSINE:
829                         {
830                                 int                             cnt;
831
832                                 power = sqrt( ((double)(sa->nelems)) * ((double)(sb->nelems)) );
833
834                                 if (  ((double)Min(sa->nelems, sb->nelems)) / power < GetSmlarLimit()  )
835                                         PG_RETURN_BOOL(false);
836
837                                 cnt = numOfIntersect(sa, sb);
838                                 power = ((double)cnt) / power;
839                         }
840                         break;
841                 case ST_OVERLAP:
842                         power = (double)numOfIntersect(sa, sb);
843                         break;
844                 default:
845                         elog(ERROR,"Unsupported formula type of similarity");
846         }
847
848         PG_RETURN_BOOL(power >= GetSmlarLimit());
849 }
850
/*
 * Single-entry plan cache used by arraysml_func(): the saved SPI plan of the
 * last prepared formula, that formula's length in bytes, and its text.
 * QBSIZE bounds both the cached formula text and the query buffer that is
 * built around it.  Static storage starts out zeroed, i.e. "no plan cached".
 */
#define QBSIZE		8192

static void	   *cachedPlan;
static int		cachedLen;
static char		cachedFormula[QBSIZE];
855
856 PG_FUNCTION_INFO_V1(arraysml_func);
857 Datum   arraysml_func(PG_FUNCTION_ARGS);
858 Datum
859 arraysml_func(PG_FUNCTION_ARGS)
860 {
861         ArrayType               *a, *b;
862         SimpleArray             *sa, *sb;
863         int                             cnt;
864         float4                  result = -1.0;
865         Oid                             arg[] = {INT4OID, INT4OID, INT4OID};
866         Datum                   pars[3];
867         bool                    isnull;
868         void                    *plan;
869         int                             stat;
870         text                    *formula = PG_GETARG_TEXT_P(2);
871
872         fcinfo->flinfo->fn_extra = SearchArrayCache(
873                                                         fcinfo->flinfo->fn_extra,
874                                                         fcinfo->flinfo->fn_mcxt,
875                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
876         fcinfo->flinfo->fn_extra = SearchArrayCache(
877                                                         fcinfo->flinfo->fn_extra,
878                                                         fcinfo->flinfo->fn_mcxt,
879                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
880
881         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
882                 elog(ERROR,"Arguments array are not the same type!");
883
884         if (ARRISVOID(a) || ARRISVOID(b))
885                  PG_RETURN_BOOL(false);
886
887         cnt = numOfIntersect(sa, sb);
888
889         if ( VARSIZE(formula) - VARHDRSZ > QBSIZE - 1024 )
890                 elog(ERROR,"Formula is too long");
891
892         SPI_connect();
893
894         if ( cachedPlan == NULL || cachedLen != VARSIZE(formula) - VARHDRSZ ||
895                                 memcmp( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ) != 0 )
896         {
897                 char                    *ptr, buf[QBSIZE];
898
899                 *cachedFormula = '\0';
900                 if ( cachedPlan )
901                         SPI_freeplan(cachedPlan);
902                 cachedPlan = NULL;
903                 cachedLen = 0;
904
905                 ptr = stpcpy( buf, "SELECT (" );
906                 memcpy( ptr, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
907                 ptr += VARSIZE(formula) - VARHDRSZ;
908                 ptr = stpcpy( ptr, ")::float4 FROM");
909                 ptr = stpcpy( ptr, " (SELECT $1 ::float8 AS i, $2 ::float8 AS a, $3 ::float8 AS b) AS N;");
910                 *ptr = '\0';
911
912                 plan = SPI_prepare(buf, 3, arg);
913                 if (!plan)
914                         elog(ERROR, "SPI_prepare() failed");
915
916                 cachedPlan = SPI_saveplan(plan);
917                 if (!cachedPlan)
918                         elog(ERROR, "SPI_saveplan() failed");
919
920                 SPI_freeplan(plan);
921                 cachedLen = VARSIZE(formula) - VARHDRSZ;
922                 memcpy( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
923         }
924
925         plan = cachedPlan;
926
927
928         pars[0] = Int32GetDatum( cnt );
929         pars[1] = Int32GetDatum( sa->nelems );
930         pars[2] = Int32GetDatum( sb->nelems );
931
932         stat = SPI_execute_plan(plan, pars, NULL, true, 3);
933         if (stat < 0)
934                 elog(ERROR, "SPI_execute_plan() returns %d", stat);
935
936         if ( SPI_processed > 0)
937                 result = DatumGetFloat4(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
938
939         SPI_finish();
940
941         PG_RETURN_FLOAT4(result);
942 }
943
944 PG_FUNCTION_INFO_V1(array_unique);
945 Datum   array_unique(PG_FUNCTION_ARGS);
946 Datum
947 array_unique(PG_FUNCTION_ARGS)
948 {
949         ArrayType               *a = PG_GETARG_ARRAYTYPE_P(0);
950         ArrayType               *res;
951         SimpleArray             *sa;
952
953         sa = Array2SimpleArrayU(NULL, a, NULL);
954
955         res = construct_array(  sa->elems, 
956                                                         sa->nelems,
957                                                         sa->info->typid,
958                                                         sa->info->typlen,
959                                                         sa->info->typbyval,
960                                                         sa->info->typalign);
961
962         pfree(sa->elems);
963         pfree(sa);
964         PG_FREE_IF_COPY(a, 0);
965
966         PG_RETURN_ARRAYTYPE_P(res);
967 }
968
969 PG_FUNCTION_INFO_V1(inarray);
970 Datum   inarray(PG_FUNCTION_ARGS);
971 Datum
972 inarray(PG_FUNCTION_ARGS)
973 {
974         ArrayType               *a;
975         SimpleArray             *sa;
976         Datum                   query = PG_GETARG_DATUM(1);
977         Oid                             queryTypeOid;
978         Datum                   *StopLow,
979                                         *StopHigh,
980                                         *StopMiddle;
981         int                             cmp;
982
983         fcinfo->flinfo->fn_extra = SearchArrayCache(
984                                                         fcinfo->flinfo->fn_extra,
985                                                         fcinfo->flinfo->fn_mcxt,
986                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
987
988         queryTypeOid = get_fn_expr_argtype(fcinfo->flinfo, 1);
989
990         if ( queryTypeOid == InvalidOid )
991                 elog(ERROR,"inarray: could not determine actual argument type");
992
993         if ( queryTypeOid != sa->info->typid )
994                 elog(ERROR,"inarray: Type of array's element and type of argument are not the same");
995
996         getFmgrInfoCmp(sa->info);
997         StopLow = sa->elems;
998         StopHigh = sa->elems + sa->nelems;
999
1000         while (StopLow < StopHigh)
1001         {
1002                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
1003                 cmp = cmpArrayElem(StopMiddle, &query, sa->info);
1004
1005                 if ( cmp == 0 )
1006                 {
1007                         /* found */
1008                         if ( PG_NARGS() >= 3 )
1009                                 PG_RETURN_DATUM(PG_GETARG_DATUM(2));
1010                         PG_RETURN_FLOAT4(1.0);
1011                 }
1012                 else if (cmp < 0)
1013                         StopLow = StopMiddle + 1;
1014                 else
1015                         StopHigh = StopMiddle;
1016         }
1017
1018         if ( PG_NARGS() >= 4 )
1019                 PG_RETURN_DATUM(PG_GETARG_DATUM(3));
1020         PG_RETURN_FLOAT4(0.0);
1021 }