1.0 version
[smlar.git] / smlar.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/genam.h"
5 #include "access/heapam.h"
6 #include "access/nbtree.h"
7 #include "catalog/indexing.h"
8 #include "catalog/pg_am.h"
9 #include "catalog/pg_amproc.h"
10 #include "catalog/pg_cast.h"
11 #include "catalog/pg_opclass.h"
12 #include "catalog/pg_type.h"
13 #include "executor/spi.h"
14 #include "utils/fmgroids.h"
15 #include "utils/lsyscache.h"
16 #include "utils/memutils.h"
17 #include "utils/tqual.h"
18 #include "utils/syscache.h"
19 #include "utils/typcache.h"
20
21 PG_MODULE_MAGIC;
22
23 static Oid
24 getDefaultOpclass(Oid amoid, Oid typid) 
25 {
26         ScanKeyData     skey;
27         SysScanDesc     scan;
28         HeapTuple       tuple;
29         Relation        heapRel;
30         Oid                     opclassOid = InvalidOid;
31
32         heapRel = heap_open(OperatorClassRelationId, AccessShareLock);  
33
34         ScanKeyInit(&skey,
35                                 Anum_pg_opclass_opcmethod,
36                                 BTEqualStrategyNumber,  F_OIDEQ,
37                                 ObjectIdGetDatum(amoid));
38
39         scan = systable_beginscan(heapRel, 
40                                                                 OpclassAmNameNspIndexId, true, 
41                                                                 SnapshotNow, 1, &skey);
42
43         while (HeapTupleIsValid((tuple = systable_getnext(scan))))
44         {
45                 Form_pg_opclass opclass = (Form_pg_opclass)GETSTRUCT(tuple);
46
47                 if ( opclass->opcintype == typid && opclass->opcdefault )
48                 {
49                         if ( OidIsValid(opclassOid) )
50                                 elog(ERROR, "Ambiguous opclass for type %u (access method %u)", typid, amoid); 
51                         opclassOid = HeapTupleGetOid(tuple);
52                 }
53         }
54
55         systable_endscan(scan);
56         heap_close(heapRel, AccessShareLock);
57
58         return opclassOid;
59 }
60
61 static Oid
62 getAMProc(Oid amoid, Oid typid) 
63 {
64         Oid             opclassOid = getDefaultOpclass(amoid, typid);
65         Oid             procOid = InvalidOid;
66         Oid             opfamilyOid;
67         ScanKeyData     skey[4];
68         SysScanDesc     scan;
69         HeapTuple       tuple;
70         Relation        heapRel;
71
72         if ( !OidIsValid(opclassOid) )
73         {
74                 typid = getBaseType(typid);
75                 opclassOid = getDefaultOpclass(amoid, typid);
76         }
77
78         if ( !OidIsValid(opclassOid) )
79         {
80                 CatCList        *catlist;
81                 int                     i;
82
83                 /*
84                  * Search binary-coercible type
85                  */
86                 catlist = SearchSysCacheList(CASTSOURCETARGET, 1,
87                                                                                 ObjectIdGetDatum(typid),
88                                                                                 0, 0, 0);
89
90                 for (i = 0; i < catlist->n_members; i++)
91                 {
92                         HeapTuple       tuple = &catlist->members[i]->tuple;
93                         Form_pg_cast    castForm = (Form_pg_cast)GETSTRUCT(tuple);
94
95                         if ( castForm->castfunc == InvalidOid && castForm->castcontext == COERCION_CODE_IMPLICIT )
96                         {
97                                 typid = castForm->casttarget;
98                                 opclassOid = getDefaultOpclass(amoid, typid);
99                                 if( OidIsValid(opclassOid) )
100                                         break;
101                         }
102                 }
103
104                 ReleaseSysCacheList(catlist);
105         }
106         
107         if ( !OidIsValid(opclassOid) )
108                 return InvalidOid;
109
110         opfamilyOid = get_opclass_family(opclassOid);
111
112         heapRel = heap_open(AccessMethodProcedureRelationId, AccessShareLock);
113         ScanKeyInit(&skey[0],
114                                 Anum_pg_amproc_amprocfamily,
115                                 BTEqualStrategyNumber, F_OIDEQ,
116                                 ObjectIdGetDatum(opfamilyOid));
117         ScanKeyInit(&skey[1],
118                                 Anum_pg_amproc_amproclefttype,
119                                 BTEqualStrategyNumber, F_OIDEQ,
120                                 ObjectIdGetDatum(typid));
121         ScanKeyInit(&skey[2],
122                                 Anum_pg_amproc_amprocrighttype,
123                                 BTEqualStrategyNumber, F_OIDEQ,
124                                 ObjectIdGetDatum(typid));
125 #if PG_VERSION_NUM >= 90200
126         ScanKeyInit(&skey[3],
127                                 Anum_pg_amproc_amprocnum,
128                                 BTEqualStrategyNumber, F_OIDEQ,
129                                 Int32GetDatum(BTORDER_PROC));
130 #endif
131
132         scan = systable_beginscan(heapRel, AccessMethodProcedureIndexId, true,
133                                                                 SnapshotNow, 
134 #if PG_VERSION_NUM >= 90200
135                                                                 4,
136 #else
137                                                                 3,
138 #endif
139                                                                 skey);
140         while (HeapTupleIsValid(tuple = systable_getnext(scan)))
141         {
142                 Form_pg_amproc amprocform = (Form_pg_amproc) GETSTRUCT(tuple);
143
144                 switch(amoid)
145                 {
146                         case BTREE_AM_OID:
147                         case HASH_AM_OID:
148                                 if ( OidIsValid(procOid) )
149                                         elog(ERROR,"Ambiguous support function for type %u (opclass %u)", typid, opfamilyOid);
150                                 procOid = amprocform->amproc;
151                                 break;
152                         default:
153                                 elog(ERROR,"Unsupported access method");
154                 }
155         }
156
157         systable_endscan(scan);
158         heap_close(heapRel, AccessShareLock);
159
160         return procOid;
161 }
162
/*
 * Cache of per-type descriptors filled by findProcs().  The array is
 * allocated with malloc/realloc and kept sorted by typid so that it can
 * be binary-searched.
 */
static ProcTypeInfo *cacheProcs = NULL;
/* number of entries currently stored in cacheProcs */
static int nCacheProcs = 0;
165
166 static ProcTypeInfo
167 fillProcs(Oid typid)
168 {
169         ProcTypeInfo    info = malloc(sizeof(ProcTypeInfoData));
170
171         if (!info)
172                 elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfoData));
173
174         info->typid = typid;
175         info->typtype = get_typtype(typid);
176
177         if (info->typtype == 'c')
178         {
179                 /* composite type */
180                 TupleDesc               tupdesc;
181                 MemoryContext   oldcontext;
182
183                 tupdesc = lookup_rowtype_tupdesc(typid, -1);
184
185                 if (tupdesc->natts != 2)
186                         elog(ERROR,"Composite type has wrong number of fields");
187                 if (tupdesc->attrs[1]->atttypid != FLOAT4OID)
188                         elog(ERROR,"Second field of composite type is not float4");
189
190                 oldcontext = MemoryContextSwitchTo(TopMemoryContext);
191                 info->tupDesc = CreateTupleDescCopyConstr(tupdesc);
192                 MemoryContextSwitchTo(oldcontext);
193
194                 ReleaseTupleDesc(tupdesc);
195
196                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, info->tupDesc->attrs[0]->atttypid);
197                 info->hashFuncOid = getAMProc(HASH_AM_OID, info->tupDesc->attrs[0]->atttypid);
198         } 
199         else 
200         {
201                 info->tupDesc = NULL;
202
203                 /* plain type */
204                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, typid);
205                 info->hashFuncOid = getAMProc(HASH_AM_OID, typid);
206         }
207
208         get_typlenbyvalalign(typid, &info->typlen, &info->typbyval, &info->typalign);
209         info->hashFuncInited = info->cmpFuncInited = false;
210
211
212         return info;
213 }
214
215 void
216 getFmgrInfoCmp(ProcTypeInfo info)
217 {
218         if ( info->cmpFuncInited == false )
219         {
220                 if ( !OidIsValid(info->cmpFuncOid) )
221                         elog(ERROR, "Could not find cmp function for type %u", info->typid);
222
223                 fmgr_info_cxt( info->cmpFuncOid, &info->cmpFunc, TopMemoryContext );
224                 info->cmpFuncInited = true;
225         }
226 }
227
228 void
229 getFmgrInfoHash(ProcTypeInfo info)
230 {
231         if ( info->hashFuncInited == false )
232         {
233                 if ( !OidIsValid(info->hashFuncOid) )
234                         elog(ERROR, "Could not find hash function for type %u", info->typid);
235
236                 fmgr_info_cxt( info->hashFuncOid, &info->hashFunc, TopMemoryContext );
237                 info->hashFuncInited = true;
238         }
239 }
240
241 static int
242 cmpProcTypeInfo(const void *a, const void *b)
243 {
244         ProcTypeInfo av = *(ProcTypeInfo*)a;
245         ProcTypeInfo bv = *(ProcTypeInfo*)b;
246
247         Assert( av->typid != bv->typid );
248
249         return ( av->typid > bv->typid ) ? 1 : -1; 
250 }
251
252 ProcTypeInfo
253 findProcs(Oid typid)
254 {
255         ProcTypeInfo    info = NULL;
256
257         if ( nCacheProcs == 1 )
258         {
259                 if ( cacheProcs[0]->typid == typid ) 
260                 {
261                         /*cacheProcs[0]->hashFuncInited = cacheProcs[0]->cmpFuncInited = false;*/
262                         return cacheProcs[0];
263                 }
264         }
265         else if ( nCacheProcs > 1 )
266         {
267                 ProcTypeInfo    *StopMiddle;
268                 ProcTypeInfo    *StopLow = cacheProcs, 
269                                                 *StopHigh = cacheProcs + nCacheProcs;
270
271                 while (StopLow < StopHigh) {
272                         StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
273                         info = *StopMiddle;
274
275                         if ( info->typid == typid )
276                         {
277                                 /* info->hashFuncInited = info->cmpFuncInited = false; */
278                                 return info;
279                         }
280                         else if ( info->typid < typid )
281                                 StopLow = StopMiddle + 1;
282                         else
283                                 StopHigh = StopMiddle;
284                 }
285
286                 /* not found */
287         } 
288
289         info = fillProcs(typid);
290         if ( nCacheProcs == 0 ) 
291         {
292                 cacheProcs = malloc(sizeof(ProcTypeInfo));
293
294                 if (!cacheProcs)
295                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo));
296                 else
297                 {
298                         nCacheProcs = 1;
299                         cacheProcs[0] = info;
300                 }
301         }
302         else
303         {
304                 ProcTypeInfo    *cacheProcsTmp = realloc(cacheProcs, (nCacheProcs+1) * sizeof(ProcTypeInfo));
305
306                 if (!cacheProcsTmp)
307                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo) * (nCacheProcs+1));
308                 else
309                 {
310                         cacheProcs = cacheProcsTmp;
311                         cacheProcs[ nCacheProcs ] = info;
312                         nCacheProcs++;
313                         qsort(cacheProcs, nCacheProcs, sizeof(ProcTypeInfo), cmpProcTypeInfo);
314                 }
315         }
316
317         /* info->hashFuncInited = info->cmpFuncInited = false; */
318
319         return info;
320 }
321
322 /*
323  * WARNING. Array2SimpleArray* doesn't copy Datum!
324  */
325 SimpleArray * 
326 Array2SimpleArray(ProcTypeInfo info, ArrayType *a)
327 {
328         SimpleArray     *s = palloc(sizeof(SimpleArray));
329
330         CHECKARRVALID(a);
331
332         if ( info == NULL )
333                 info = findProcs(ARR_ELEMTYPE(a));
334
335         s->info = info;
336         s->df = NULL;
337         s->hash = NULL;
338
339         deconstruct_array(a, info->typid,
340                                                 info->typlen, info->typbyval, info->typalign,
341                                                 &s->elems, NULL, &s->nelems);
342
343         return s;
344 }
345
346 static Datum
347 deconstructCompositeType(ProcTypeInfo info, Datum in, double *weight)
348 {
349         HeapTupleHeader rec = DatumGetHeapTupleHeader(in);
350         HeapTupleData   tuple;
351         Datum                   values[2];
352         bool                    nulls[2];
353
354         /* Build a temporary HeapTuple control structure */
355         tuple.t_len = HeapTupleHeaderGetDatumLength(rec);
356         ItemPointerSetInvalid(&(tuple.t_self));
357         tuple.t_tableOid = InvalidOid;
358         tuple.t_data = rec;
359
360         heap_deform_tuple(&tuple, info->tupDesc, values, nulls);
361         if (nulls[0] || nulls[1])
362                 elog(ERROR, "Both fields in composite type could not be NULL");
363
364         if (weight)
365                 *weight = DatumGetFloat4(values[1]);
366         return values[0];
367 }
368
369 static int
370 cmpArrayElem(const void *a, const void *b, void *arg)
371 {
372         ProcTypeInfo    info = (ProcTypeInfo)arg;
373
374         if (info->tupDesc) 
375                 /* composite type */
376                 return DatumGetInt32( FCall2( &info->cmpFunc,
377                                                                                         deconstructCompositeType(info, *(Datum*)a, NULL),
378                                                                                         deconstructCompositeType(info, *(Datum*)b, NULL) ) );
379
380         return DatumGetInt32( FCall2( &info->cmpFunc,
381                                                         *(Datum*)a, *(Datum*)b ) );
382 }
383
384 SimpleArray *
385 Array2SimpleArrayS(ProcTypeInfo info, ArrayType *a)
386 {
387         SimpleArray     *s = Array2SimpleArray(info, a);
388
389         if ( s->nelems > 1 )
390         {
391                 getFmgrInfoCmp(s->info);
392
393                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElem, s->info);
394         }
395
396         return s;
397 }
398
/*
 * Argument block handed to cmpArrayElemArg() through qsort_arg().
 */
typedef struct cmpArrayElemData {
        ProcTypeInfo    info;           /* type descriptor providing cmpFunc */
        bool                    hasDuplicate;   /* set to true by the comparator when two elements compare equal */

} cmpArrayElemData;
404
405 static int
406 cmpArrayElemArg(const void *a, const void *b, void *arg)
407 {
408     cmpArrayElemData    *data = (cmpArrayElemData*)arg;
409         int res;
410
411         if (data->info->tupDesc)
412                 res =  DatumGetInt32( FCall2( &data->info->cmpFunc,
413                                                                                         deconstructCompositeType(data->info, *(Datum*)a, NULL),
414                                                                                         deconstructCompositeType(data->info, *(Datum*)b, NULL) ) );
415         else
416                 res = DatumGetInt32( FCall2( &data->info->cmpFunc,
417                                                                 *(Datum*)a, *(Datum*)b ) );
418
419         if ( res == 0 )
420                 data->hasDuplicate = true;
421
422         return res;
423 }
424
425 /*
426  * Uniquefy array and calculate TF. Although 
427  * result doesn't depend on normalization, we
428  * normalize TF by length array to have possiblity
429  * to limit estimation for index support.
430  *
431  * Cache signals of needing of TF caclulation
432  */
433
434 SimpleArray *
435 Array2SimpleArrayU(ProcTypeInfo info, ArrayType *a, void *cache)
436 {
437         SimpleArray     *s = Array2SimpleArray(info, a);
438         StatElem        *stat = NULL;
439
440         if ( s->nelems > 0 && cache )
441         {
442                 s->df = palloc(sizeof(double) * s->nelems);
443                 s->df[0] = 1.0; /* init */
444         }
445
446         if ( s->nelems > 1 )
447         {
448                 cmpArrayElemData        data;
449                 int                             i;
450
451                 getFmgrInfoCmp(s->info);
452                 data.info = s->info;
453                 data.hasDuplicate = false;
454
455                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElemArg, &data);
456
457                 if ( data.hasDuplicate )
458                 {
459                         Datum   *tmp,
460                                         *dr,
461                                         *data;
462                         int             num = s->nelems,
463                                         cmp;
464
465                         data = tmp = dr = s->elems;
466
467                         while (tmp - data < num)
468                         {
469                                 cmp = (tmp == dr) ? 0 : cmpArrayElem(tmp, dr, s->info);
470                                 if ( cmp != 0 )
471                                 {
472                                         *(++dr) = *tmp++;
473                                         if ( cache ) 
474                                                 s->df[ dr - data ] = 1.0;
475                                 }
476                                 else
477                                 {
478                                         if ( cache )
479                                                 s->df[ dr - data ] += 1.0;
480                                         tmp++;
481                                 }
482                         }
483
484                         s->nelems = dr + 1 - s->elems;
485
486                         if ( cache )
487                         {
488                                 int tfm = getTFMethod();
489
490                                 for(i=0;i<s->nelems;i++)
491                                 {
492                                         stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
493                                         if ( stat )
494                                         {
495                                                 switch(tfm)
496                                                 {
497                                                         case TF_LOG:
498                                                                 s->df[i] = (1.0 + log( s->df[i] ));
499                                                         case TF_N:
500                                                                 s->df[i] *= stat->idf;
501                                                                 break;
502                                                         case TF_CONST:
503                                                                 s->df[i] = stat->idf;
504                                                                 break;
505                                                         default:
506                                                                 elog(ERROR,"Unknown TF method: %d", tfm);
507                                                 }       
508                                         }
509                                         else
510                                         {
511                                                 s->df[i] = 0.0; /* unknown word */
512                                         }
513                                 }       
514                         }
515                 }
516                 else if ( cache )
517                 {
518                         for(i=0;i<s->nelems;i++)
519                         {
520                                 stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
521                                 if ( stat )
522                                         s->df[i] = stat->idf;
523                                 else
524                                         s->df[i] = 0.0;
525                         }
526                 }
527         }
528         else if (s->nelems > 0 && cache) 
529         {
530                 stat = fingArrayStat(cache, s->info->typid, s->elems[0], stat);
531                 if ( stat )
532                         s->df[0] = stat->idf;
533                 else
534                         s->df[0] = 0.0;
535         }
536
537         return s;
538 }
539
540 static int
541 numOfIntersect(SimpleArray *a, SimpleArray *b)
542 {
543         int                             cnt = 0,
544                                         cmp;
545         Datum                   *aptr = a->elems,
546                                         *bptr = b->elems;
547         ProcTypeInfo    info = a->info;
548
549         Assert( a->info->typid == b->info->typid );
550
551         getFmgrInfoCmp(info);
552
553         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
554         {
555                 cmp = cmpArrayElem(aptr, bptr, info);
556                 if ( cmp < 0 )
557                         aptr++;
558                 else if ( cmp > 0 )
559                         bptr++;
560                 else
561                 {
562                         cnt++;
563                         aptr++;
564                         bptr++;
565                 }
566         }
567
568         return cnt;
569 }
570
571 static double
572 TFIDFSml(SimpleArray *a, SimpleArray *b)
573 {
574         int                             cmp;
575         Datum                   *aptr = a->elems,
576                                         *bptr = b->elems;
577         ProcTypeInfo    info = a->info;
578         double                  res = 0.0;
579         double                  suma = 0.0, sumb = 0.0;
580
581         Assert( a->info->typid == b->info->typid );
582         Assert( a->df );
583         Assert( b->df );
584
585         getFmgrInfoCmp(info);
586
587         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
588         {
589                 cmp = cmpArrayElem(aptr, bptr, info);
590                 if ( cmp < 0 )
591                 {
592                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
593                         aptr++;
594                 }
595                 else if ( cmp > 0 )
596                 {
597                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
598                         bptr++;
599                 }
600                 else
601                 {
602                         res += a->df[ aptr - a->elems ] * b->df[ bptr - b->elems ];
603                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
604                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
605                         aptr++;
606                         bptr++;
607                 }
608         }
609
610         /*
611          * Compute last elements
612          */
613         while( aptr - a->elems < a->nelems )
614         {
615                 suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
616                 aptr++;
617         }
618
619         while( bptr - b->elems < b->nelems )
620         {
621                 sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
622                 bptr++;
623         }
624
625         if ( suma > 0.0 && sumb > 0.0 )
626                 res = res / sqrt( suma * sumb );
627         else
628                 res = 0.0;
629
630         return res;
631 }
632
633
634 PG_FUNCTION_INFO_V1(arraysml);
635 Datum   arraysml(PG_FUNCTION_ARGS);
636 Datum
637 arraysml(PG_FUNCTION_ARGS)
638 {
639         ArrayType               *a, *b;
640         SimpleArray             *sa, *sb;
641         
642         fcinfo->flinfo->fn_extra = SearchArrayCache( 
643                                                         fcinfo->flinfo->fn_extra,
644                                                         fcinfo->flinfo->fn_mcxt,
645                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
646         fcinfo->flinfo->fn_extra = SearchArrayCache( 
647                                                         fcinfo->flinfo->fn_extra,
648                                                         fcinfo->flinfo->fn_mcxt,
649                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
650
651         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
652                 elog(ERROR,"Arguments array are not the same type!");
653
654         if (ARRISVOID(a) || ARRISVOID(b))
655                  PG_RETURN_FLOAT4(0.0);
656
657         switch(getSmlType())
658         {
659                 case    ST_TFIDF:
660                         PG_RETURN_FLOAT4( TFIDFSml(sa, sb) );
661                         break;
662                 case    ST_COSINE:
663                         {
664                                 int                             cnt;
665                                 double                  power;
666
667                                 power = ((double)(sa->nelems)) * ((double)(sb->nelems));
668                                 cnt = numOfIntersect(sa, sb);
669
670                                 PG_RETURN_FLOAT4(  ((double)cnt) / sqrt( power ) );
671                         }
672                         break;
673                 case    ST_OVERLAP:
674                         {
675                                 float4 res = (float4)numOfIntersect(sa, sb);
676
677                                 PG_RETURN_FLOAT4(res);
678                         }
679                         break;
680                 default:
681                         elog(ERROR,"Unsupported formula type of similarity");
682         }
683
684         PG_RETURN_FLOAT4(0.0); /* keep compiler quiet */ 
685 }
686
687 PG_FUNCTION_INFO_V1(arraysmlw);
688 Datum   arraysmlw(PG_FUNCTION_ARGS);
689 Datum
690 arraysmlw(PG_FUNCTION_ARGS)
691 {
692         ArrayType               *a, *b;
693         SimpleArray             *sa, *sb;
694         bool                    useIntersect = PG_GETARG_BOOL(2);
695         double                  numerator = 0.0;
696         double                  denominatorA = 0.0,
697                                         denominatorB = 0.0,
698                                         tmpA, tmpB;
699         int                             cmp;
700         ProcTypeInfo    info;
701         int                             ai = 0, bi = 0;
702
703         
704         fcinfo->flinfo->fn_extra = SearchArrayCache( 
705                                                         fcinfo->flinfo->fn_extra,
706                                                         fcinfo->flinfo->fn_mcxt,
707                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
708         fcinfo->flinfo->fn_extra = SearchArrayCache( 
709                                                         fcinfo->flinfo->fn_extra,
710                                                         fcinfo->flinfo->fn_mcxt,
711                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
712
713         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
714                 elog(ERROR,"Arguments array are not the same type!");
715
716         if (ARRISVOID(a) || ARRISVOID(b))
717                  PG_RETURN_FLOAT4(0.0);
718
719         info = sa->info;
720         if (info->tupDesc == NULL)
721                 elog(ERROR, "Only weigthed (composite) types should be used");
722         getFmgrInfoCmp(info);
723
724         while(ai < sa->nelems && bi < sb->nelems)
725         {
726                 Datum   ad = deconstructCompositeType(info, sa->elems[ai], &tmpA),
727                                 bd = deconstructCompositeType(info, sb->elems[bi], &tmpB);
728
729                 cmp = DatumGetInt32(FCall2(&info->cmpFunc, ad, bd));
730
731                 if ( cmp < 0 ) {
732                         if (useIntersect == false)
733                                 denominatorA += tmpA * tmpA;
734                         ai++;
735                 } else if ( cmp > 0 ) {
736                         if (useIntersect == false)
737                                 denominatorB += tmpB * tmpB;
738                         bi++;
739                 } else {
740                         denominatorA += tmpA * tmpA;
741                         denominatorB += tmpB * tmpB;
742                         numerator += tmpA * tmpB;
743                         ai++;
744                         bi++;
745                 }
746         }
747
748         if (useIntersect == false) {
749                 while(ai < sa->nelems) {
750                         deconstructCompositeType(info, sa->elems[ai], &tmpA);
751                         denominatorA += tmpA * tmpA;
752                         ai++;
753                 }
754
755                 while(bi < sb->nelems) {
756                         deconstructCompositeType(info, sb->elems[bi], &tmpB);
757                         denominatorB += tmpB * tmpB;
758                         bi++;
759                 }
760         }
761
762         if (numerator != 0.0) {
763                 numerator = numerator / sqrt( denominatorA * denominatorB );
764         }
765
766         PG_RETURN_FLOAT4(numerator);
767 }
768
769 PG_FUNCTION_INFO_V1(arraysml_op);
770 Datum   arraysml_op(PG_FUNCTION_ARGS);
771 Datum
772 arraysml_op(PG_FUNCTION_ARGS)
773 {
774         ArrayType               *a, *b;
775         SimpleArray             *sa, *sb;
776         double                  power = 0.0;
777         
778         fcinfo->flinfo->fn_extra = SearchArrayCache( 
779                                                         fcinfo->flinfo->fn_extra,
780                                                         fcinfo->flinfo->fn_mcxt,
781                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
782         fcinfo->flinfo->fn_extra = SearchArrayCache( 
783                                                         fcinfo->flinfo->fn_extra,
784                                                         fcinfo->flinfo->fn_mcxt,
785                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
786
787         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
788                 elog(ERROR,"Arguments array are not the same type!");
789
790         if (ARRISVOID(a) || ARRISVOID(b))
791                  PG_RETURN_BOOL(false);
792
793         switch(getSmlType())
794         {
795                 case    ST_TFIDF:
796                         power = TFIDFSml(sa, sb);
797                         break;
798                 case    ST_COSINE:
799                         {
800                                 int                             cnt;
801
802                                 power = sqrt( ((double)(sa->nelems)) * ((double)(sb->nelems)) );
803
804                                 if (  ((double)Min(sa->nelems, sb->nelems)) / power < GetSmlarLimit()  )
805                                         PG_RETURN_BOOL(false);
806
807                                 cnt = numOfIntersect(sa, sb);
808                                 power = ((double)cnt) / power;
809                         }
810                         break;
811                 case    ST_OVERLAP:
812                         power = (double)numOfIntersect(sa, sb);
813                         break;
814                 default:
815                         elog(ERROR,"Unsupported formula type of similarity");
816         }
817
818         PG_RETURN_BOOL(power >= GetSmlarLimit());
819 }
820
821 #define QBSIZE          8192
822 static char cachedFormula[QBSIZE];
823 static int      cachedLen  = 0;
824 static void     *cachedPlan = NULL;
825
826 PG_FUNCTION_INFO_V1(arraysml_func);
827 Datum   arraysml_func(PG_FUNCTION_ARGS);
828 Datum
829 arraysml_func(PG_FUNCTION_ARGS)
830 {
831         ArrayType               *a, *b;
832         SimpleArray             *sa, *sb;
833         int                             cnt;
834         float4                  result = -1.0;
835         Oid                             arg[] = {INT4OID, INT4OID, INT4OID};
836         Datum                   pars[3];
837         bool                    isnull;
838         void                    *plan;
839         int                             stat;
840         text                    *formula = PG_GETARG_TEXT_P(2); 
841         
842         fcinfo->flinfo->fn_extra = SearchArrayCache( 
843                                                         fcinfo->flinfo->fn_extra,
844                                                         fcinfo->flinfo->fn_mcxt,
845                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
846         fcinfo->flinfo->fn_extra = SearchArrayCache( 
847                                                         fcinfo->flinfo->fn_extra,
848                                                         fcinfo->flinfo->fn_mcxt,
849                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
850
851         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
852                 elog(ERROR,"Arguments array are not the same type!");
853
854         if (ARRISVOID(a) || ARRISVOID(b))
855                  PG_RETURN_BOOL(false);
856
857         cnt = numOfIntersect(sa, sb);
858
859         if ( VARSIZE(formula) - VARHDRSZ > QBSIZE - 1024 )
860                 elog(ERROR,"Formula is too long");
861
862         SPI_connect();
863
864         
865         if ( cachedPlan == NULL || cachedLen != VARSIZE(formula) - VARHDRSZ ||
866                                 memcmp( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ) != 0 )
867         {
868                 char                    *ptr, buf[QBSIZE];
869
870                 *cachedFormula = '\0';
871                 if ( cachedPlan )
872                         SPI_freeplan(cachedPlan);
873                 cachedPlan = NULL;
874                 cachedLen = 0;
875
876                 ptr = stpcpy( buf, "SELECT (" );
877                 memcpy( ptr, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
878                 ptr += VARSIZE(formula) - VARHDRSZ;
879                 ptr = stpcpy( ptr, ")::float4 FROM");
880                 ptr = stpcpy( ptr, " (SELECT $1 ::float8 AS i, $2 ::float8 AS a, $3 ::float8 AS b) AS N;");
881                 *ptr = '\0';
882
883                 plan = SPI_prepare(buf, 3, arg);
884                 if (!plan)
885                         elog(ERROR, "SPI_prepare() failed");
886
887                 cachedPlan = SPI_saveplan(plan);
888                 if (!cachedPlan)
889                         elog(ERROR, "SPI_saveplan() failed");
890
891                 SPI_freeplan(plan);
892                 cachedLen = VARSIZE(formula) - VARHDRSZ;
893                 memcpy( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ); 
894         }
895
896         plan = cachedPlan;
897
898
899         pars[0] = Int32GetDatum( cnt );
900         pars[1] = Int32GetDatum( sa->nelems );
901         pars[2] = Int32GetDatum( sb->nelems );
902
903         stat = SPI_execute_plan(plan, pars, NULL, true, 3);
904         if (stat < 0)
905                 elog(ERROR, "SPI_execute_plan() returns %d", stat);
906
907         if ( SPI_processed > 0)
908                 result = DatumGetFloat4(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
909
910         SPI_finish();
911
912         PG_RETURN_FLOAT4(result);
913 }
914
915 PG_FUNCTION_INFO_V1(array_unique);
916 Datum   array_unique(PG_FUNCTION_ARGS);
917 Datum
918 array_unique(PG_FUNCTION_ARGS)
919 {
920         ArrayType               *a = PG_GETARG_ARRAYTYPE_P(0);
921         ArrayType               *res;
922         SimpleArray             *sa;
923
924         sa = Array2SimpleArrayU(NULL, a, NULL);
925
926         res = construct_array(  sa->elems, 
927                                                         sa->nelems,
928                                                         sa->info->typid,
929                                                         sa->info->typlen,
930                                                         sa->info->typbyval,
931                                                         sa->info->typalign);
932
933         pfree(sa->elems);
934         pfree(sa);
935         PG_FREE_IF_COPY(a, 0);
936
937         PG_RETURN_ARRAYTYPE_P(res);
938 }
939
940 PG_FUNCTION_INFO_V1(inarray);
941 Datum   inarray(PG_FUNCTION_ARGS);
942 Datum
943 inarray(PG_FUNCTION_ARGS)
944 {
945         ArrayType               *a;
946         SimpleArray             *sa;
947         Datum                   query = PG_GETARG_DATUM(1); 
948         Oid                             queryTypeOid;
949         Datum                   *StopLow,
950                                         *StopHigh,
951                                         *StopMiddle;
952         int                             cmp;
953
954         fcinfo->flinfo->fn_extra = SearchArrayCache( 
955                                                         fcinfo->flinfo->fn_extra,
956                                                         fcinfo->flinfo->fn_mcxt,
957                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
958
959         queryTypeOid = get_fn_expr_argtype(fcinfo->flinfo, 1);
960
961         if ( queryTypeOid == InvalidOid )
962                 elog(ERROR,"inarray: could not determine actual argument type");
963
964         if ( queryTypeOid != sa->info->typid )
965                 elog(ERROR,"inarray: Type of array's element and type of argument are not the same");
966
967         getFmgrInfoCmp(sa->info);
968         StopLow = sa->elems;
969         StopHigh = sa->elems + sa->nelems;
970
971         while (StopLow < StopHigh) 
972         {
973                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
974                 cmp = cmpArrayElem(StopMiddle, &query, sa->info);
975
976                 if ( cmp == 0 )
977                 {
978                         /* found */
979                         if ( PG_NARGS() >= 3 )
980                                 PG_RETURN_DATUM(PG_GETARG_DATUM(2));
981                         PG_RETURN_FLOAT4(1.0);
982                 }
983                 else if (cmp < 0)
984                         StopLow = StopMiddle + 1;
985                 else
986                         StopHigh = StopMiddle;
987         }
988
989         if ( PG_NARGS() >= 4 )
990                 PG_RETURN_DATUM(PG_GETARG_DATUM(3));
991         PG_RETURN_FLOAT4(0.0);
992 }