7b34c6ee6a615ca3528a831d01337d49ab784ac5
[smlar.git] / smlar.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/genam.h"
5 #include "access/heapam.h"
6 #include "access/htup_details.h"
7 #include "access/nbtree.h"
8 #include "catalog/indexing.h"
9 #include "catalog/pg_am.h"
10 #include "catalog/pg_amproc.h"
11 #include "catalog/pg_cast.h"
12 #include "catalog/pg_opclass.h"
13 #include "catalog/pg_type.h"
14 #include "executor/spi.h"
15 #include "utils/catcache.h"
16 #include "utils/fmgroids.h"
17 #include "utils/lsyscache.h"
18 #include "utils/memutils.h"
19 #include "utils/tqual.h"
20 #include "utils/syscache.h"
21 #include "utils/typcache.h"
22
23 PG_MODULE_MAGIC;
24
25 static Oid
26 getDefaultOpclass(Oid amoid, Oid typid) 
27 {
28         ScanKeyData     skey;
29         SysScanDesc     scan;
30         HeapTuple       tuple;
31         Relation        heapRel;
32         Oid                     opclassOid = InvalidOid;
33
34         heapRel = heap_open(OperatorClassRelationId, AccessShareLock);  
35
36         ScanKeyInit(&skey,
37                                 Anum_pg_opclass_opcmethod,
38                                 BTEqualStrategyNumber,  F_OIDEQ,
39                                 ObjectIdGetDatum(amoid));
40
41         scan = systable_beginscan(heapRel, 
42                                                                 OpclassAmNameNspIndexId, true, 
43                                                                 SnapshotNow, 1, &skey);
44
45         while (HeapTupleIsValid((tuple = systable_getnext(scan))))
46         {
47                 Form_pg_opclass opclass = (Form_pg_opclass)GETSTRUCT(tuple);
48
49                 if ( opclass->opcintype == typid && opclass->opcdefault )
50                 {
51                         if ( OidIsValid(opclassOid) )
52                                 elog(ERROR, "Ambiguous opclass for type %u (access method %u)", typid, amoid); 
53                         opclassOid = HeapTupleGetOid(tuple);
54                 }
55         }
56
57         systable_endscan(scan);
58         heap_close(heapRel, AccessShareLock);
59
60         return opclassOid;
61 }
62
63 static Oid
64 getAMProc(Oid amoid, Oid typid) 
65 {
66         Oid             opclassOid = getDefaultOpclass(amoid, typid);
67         Oid             procOid = InvalidOid;
68         Oid             opfamilyOid;
69         ScanKeyData     skey[4];
70         SysScanDesc     scan;
71         HeapTuple       tuple;
72         Relation        heapRel;
73
74         if ( !OidIsValid(opclassOid) )
75         {
76                 typid = getBaseType(typid);
77                 opclassOid = getDefaultOpclass(amoid, typid);
78         }
79
80         if ( !OidIsValid(opclassOid) )
81         {
82                 CatCList        *catlist;
83                 int                     i;
84
85                 /*
86                  * Search binary-coercible type
87                  */
88                 catlist = SearchSysCacheList(CASTSOURCETARGET, 1,
89                                                                                 ObjectIdGetDatum(typid),
90                                                                                 0, 0, 0);
91
92                 for (i = 0; i < catlist->n_members; i++)
93                 {
94                         HeapTuple       tuple = &catlist->members[i]->tuple;
95                         Form_pg_cast    castForm = (Form_pg_cast)GETSTRUCT(tuple);
96
97                         if ( castForm->castfunc == InvalidOid && castForm->castcontext == COERCION_CODE_IMPLICIT )
98                         {
99                                 typid = castForm->casttarget;
100                                 opclassOid = getDefaultOpclass(amoid, typid);
101                                 if( OidIsValid(opclassOid) )
102                                         break;
103                         }
104                 }
105
106                 ReleaseSysCacheList(catlist);
107         }
108         
109         if ( !OidIsValid(opclassOid) )
110                 return InvalidOid;
111
112         opfamilyOid = get_opclass_family(opclassOid);
113
114         heapRel = heap_open(AccessMethodProcedureRelationId, AccessShareLock);
115         ScanKeyInit(&skey[0],
116                                 Anum_pg_amproc_amprocfamily,
117                                 BTEqualStrategyNumber, F_OIDEQ,
118                                 ObjectIdGetDatum(opfamilyOid));
119         ScanKeyInit(&skey[1],
120                                 Anum_pg_amproc_amproclefttype,
121                                 BTEqualStrategyNumber, F_OIDEQ,
122                                 ObjectIdGetDatum(typid));
123         ScanKeyInit(&skey[2],
124                                 Anum_pg_amproc_amprocrighttype,
125                                 BTEqualStrategyNumber, F_OIDEQ,
126                                 ObjectIdGetDatum(typid));
127 #if PG_VERSION_NUM >= 90200
128         ScanKeyInit(&skey[3],
129                                 Anum_pg_amproc_amprocnum,
130                                 BTEqualStrategyNumber, F_OIDEQ,
131                                 Int32GetDatum(BTORDER_PROC));
132 #endif
133
134         scan = systable_beginscan(heapRel, AccessMethodProcedureIndexId, true,
135                                                                 SnapshotNow, 
136 #if PG_VERSION_NUM >= 90200
137                                                                 4,
138 #else
139                                                                 3,
140 #endif
141                                                                 skey);
142         while (HeapTupleIsValid(tuple = systable_getnext(scan)))
143         {
144                 Form_pg_amproc amprocform = (Form_pg_amproc) GETSTRUCT(tuple);
145
146                 switch(amoid)
147                 {
148                         case BTREE_AM_OID:
149                         case HASH_AM_OID:
150                                 if ( OidIsValid(procOid) )
151                                         elog(ERROR,"Ambiguous support function for type %u (opclass %u)", typid, opfamilyOid);
152                                 procOid = amprocform->amproc;
153                                 break;
154                         default:
155                                 elog(ERROR,"Unsupported access method");
156                 }
157         }
158
159         systable_endscan(scan);
160         heap_close(heapRel, AccessShareLock);
161
162         return procOid;
163 }
164
165 static ProcTypeInfo *cacheProcs = NULL;
166 static int nCacheProcs = 0;
167
168 static ProcTypeInfo
169 fillProcs(Oid typid)
170 {
171         ProcTypeInfo    info = malloc(sizeof(ProcTypeInfoData));
172
173         if (!info)
174                 elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfoData));
175
176         info->typid = typid;
177         info->typtype = get_typtype(typid);
178
179         if (info->typtype == 'c')
180         {
181                 /* composite type */
182                 TupleDesc               tupdesc;
183                 MemoryContext   oldcontext;
184
185                 tupdesc = lookup_rowtype_tupdesc(typid, -1);
186
187                 if (tupdesc->natts != 2)
188                         elog(ERROR,"Composite type has wrong number of fields");
189                 if (tupdesc->attrs[1]->atttypid != FLOAT4OID)
190                         elog(ERROR,"Second field of composite type is not float4");
191
192                 oldcontext = MemoryContextSwitchTo(TopMemoryContext);
193                 info->tupDesc = CreateTupleDescCopyConstr(tupdesc);
194                 MemoryContextSwitchTo(oldcontext);
195
196                 ReleaseTupleDesc(tupdesc);
197
198                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, info->tupDesc->attrs[0]->atttypid);
199                 info->hashFuncOid = getAMProc(HASH_AM_OID, info->tupDesc->attrs[0]->atttypid);
200         } 
201         else 
202         {
203                 info->tupDesc = NULL;
204
205                 /* plain type */
206                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, typid);
207                 info->hashFuncOid = getAMProc(HASH_AM_OID, typid);
208         }
209
210         get_typlenbyvalalign(typid, &info->typlen, &info->typbyval, &info->typalign);
211         info->hashFuncInited = info->cmpFuncInited = false;
212
213
214         return info;
215 }
216
217 void
218 getFmgrInfoCmp(ProcTypeInfo info)
219 {
220         if ( info->cmpFuncInited == false )
221         {
222                 if ( !OidIsValid(info->cmpFuncOid) )
223                         elog(ERROR, "Could not find cmp function for type %u", info->typid);
224
225                 fmgr_info_cxt( info->cmpFuncOid, &info->cmpFunc, TopMemoryContext );
226                 info->cmpFuncInited = true;
227         }
228 }
229
230 void
231 getFmgrInfoHash(ProcTypeInfo info)
232 {
233         if ( info->hashFuncInited == false )
234         {
235                 if ( !OidIsValid(info->hashFuncOid) )
236                         elog(ERROR, "Could not find hash function for type %u", info->typid);
237
238                 fmgr_info_cxt( info->hashFuncOid, &info->hashFunc, TopMemoryContext );
239                 info->hashFuncInited = true;
240         }
241 }
242
243 static int
244 cmpProcTypeInfo(const void *a, const void *b)
245 {
246         ProcTypeInfo av = *(ProcTypeInfo*)a;
247         ProcTypeInfo bv = *(ProcTypeInfo*)b;
248
249         Assert( av->typid != bv->typid );
250
251         return ( av->typid > bv->typid ) ? 1 : -1; 
252 }
253
254 ProcTypeInfo
255 findProcs(Oid typid)
256 {
257         ProcTypeInfo    info = NULL;
258
259         if ( nCacheProcs == 1 )
260         {
261                 if ( cacheProcs[0]->typid == typid ) 
262                 {
263                         /*cacheProcs[0]->hashFuncInited = cacheProcs[0]->cmpFuncInited = false;*/
264                         return cacheProcs[0];
265                 }
266         }
267         else if ( nCacheProcs > 1 )
268         {
269                 ProcTypeInfo    *StopMiddle;
270                 ProcTypeInfo    *StopLow = cacheProcs, 
271                                                 *StopHigh = cacheProcs + nCacheProcs;
272
273                 while (StopLow < StopHigh) {
274                         StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
275                         info = *StopMiddle;
276
277                         if ( info->typid == typid )
278                         {
279                                 /* info->hashFuncInited = info->cmpFuncInited = false; */
280                                 return info;
281                         }
282                         else if ( info->typid < typid )
283                                 StopLow = StopMiddle + 1;
284                         else
285                                 StopHigh = StopMiddle;
286                 }
287
288                 /* not found */
289         } 
290
291         info = fillProcs(typid);
292         if ( nCacheProcs == 0 ) 
293         {
294                 cacheProcs = malloc(sizeof(ProcTypeInfo));
295
296                 if (!cacheProcs)
297                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo));
298                 else
299                 {
300                         nCacheProcs = 1;
301                         cacheProcs[0] = info;
302                 }
303         }
304         else
305         {
306                 ProcTypeInfo    *cacheProcsTmp = realloc(cacheProcs, (nCacheProcs+1) * sizeof(ProcTypeInfo));
307
308                 if (!cacheProcsTmp)
309                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo) * (nCacheProcs+1));
310                 else
311                 {
312                         cacheProcs = cacheProcsTmp;
313                         cacheProcs[ nCacheProcs ] = info;
314                         nCacheProcs++;
315                         qsort(cacheProcs, nCacheProcs, sizeof(ProcTypeInfo), cmpProcTypeInfo);
316                 }
317         }
318
319         /* info->hashFuncInited = info->cmpFuncInited = false; */
320
321         return info;
322 }
323
324 /*
325  * WARNING. Array2SimpleArray* doesn't copy Datum!
326  */
327 SimpleArray * 
328 Array2SimpleArray(ProcTypeInfo info, ArrayType *a)
329 {
330         SimpleArray     *s = palloc(sizeof(SimpleArray));
331
332         CHECKARRVALID(a);
333
334         if ( info == NULL )
335                 info = findProcs(ARR_ELEMTYPE(a));
336
337         s->info = info;
338         s->df = NULL;
339         s->hash = NULL;
340
341         deconstruct_array(a, info->typid,
342                                                 info->typlen, info->typbyval, info->typalign,
343                                                 &s->elems, NULL, &s->nelems);
344
345         return s;
346 }
347
348 static Datum
349 deconstructCompositeType(ProcTypeInfo info, Datum in, double *weight)
350 {
351         HeapTupleHeader rec = DatumGetHeapTupleHeader(in);
352         HeapTupleData   tuple;
353         Datum                   values[2];
354         bool                    nulls[2];
355
356         /* Build a temporary HeapTuple control structure */
357         tuple.t_len = HeapTupleHeaderGetDatumLength(rec);
358         ItemPointerSetInvalid(&(tuple.t_self));
359         tuple.t_tableOid = InvalidOid;
360         tuple.t_data = rec;
361
362         heap_deform_tuple(&tuple, info->tupDesc, values, nulls);
363         if (nulls[0] || nulls[1])
364                 elog(ERROR, "Both fields in composite type could not be NULL");
365
366         if (weight)
367                 *weight = DatumGetFloat4(values[1]);
368         return values[0];
369 }
370
371 static int
372 cmpArrayElem(const void *a, const void *b, void *arg)
373 {
374         ProcTypeInfo    info = (ProcTypeInfo)arg;
375
376         if (info->tupDesc) 
377                 /* composite type */
378                 return DatumGetInt32( FCall2( &info->cmpFunc,
379                                                                                         deconstructCompositeType(info, *(Datum*)a, NULL),
380                                                                                         deconstructCompositeType(info, *(Datum*)b, NULL) ) );
381
382         return DatumGetInt32( FCall2( &info->cmpFunc,
383                                                         *(Datum*)a, *(Datum*)b ) );
384 }
385
386 SimpleArray *
387 Array2SimpleArrayS(ProcTypeInfo info, ArrayType *a)
388 {
389         SimpleArray     *s = Array2SimpleArray(info, a);
390
391         if ( s->nelems > 1 )
392         {
393                 getFmgrInfoCmp(s->info);
394
395                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElem, s->info);
396         }
397
398         return s;
399 }
400
401 typedef struct cmpArrayElemData {
402         ProcTypeInfo    info;
403         bool                    hasDuplicate;
404
405 } cmpArrayElemData;
406
407 static int
408 cmpArrayElemArg(const void *a, const void *b, void *arg)
409 {
410     cmpArrayElemData    *data = (cmpArrayElemData*)arg;
411         int res;
412
413         if (data->info->tupDesc)
414                 res =  DatumGetInt32( FCall2( &data->info->cmpFunc,
415                                                                                         deconstructCompositeType(data->info, *(Datum*)a, NULL),
416                                                                                         deconstructCompositeType(data->info, *(Datum*)b, NULL) ) );
417         else
418                 res = DatumGetInt32( FCall2( &data->info->cmpFunc,
419                                                                 *(Datum*)a, *(Datum*)b ) );
420
421         if ( res == 0 )
422                 data->hasDuplicate = true;
423
424         return res;
425 }
426
427 /*
428  * Uniquefy array and calculate TF. Although 
429  * result doesn't depend on normalization, we
430  * normalize TF by length array to have possiblity
431  * to limit estimation for index support.
432  *
433  * Cache signals of needing of TF caclulation
434  */
435
436 SimpleArray *
437 Array2SimpleArrayU(ProcTypeInfo info, ArrayType *a, void *cache)
438 {
439         SimpleArray     *s = Array2SimpleArray(info, a);
440         StatElem        *stat = NULL;
441
442         if ( s->nelems > 0 && cache )
443         {
444                 s->df = palloc(sizeof(double) * s->nelems);
445                 s->df[0] = 1.0; /* init */
446         }
447
448         if ( s->nelems > 1 )
449         {
450                 cmpArrayElemData        data;
451                 int                             i;
452
453                 getFmgrInfoCmp(s->info);
454                 data.info = s->info;
455                 data.hasDuplicate = false;
456
457                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElemArg, &data);
458
459                 if ( data.hasDuplicate )
460                 {
461                         Datum   *tmp,
462                                         *dr,
463                                         *data;
464                         int             num = s->nelems,
465                                         cmp;
466
467                         data = tmp = dr = s->elems;
468
469                         while (tmp - data < num)
470                         {
471                                 cmp = (tmp == dr) ? 0 : cmpArrayElem(tmp, dr, s->info);
472                                 if ( cmp != 0 )
473                                 {
474                                         *(++dr) = *tmp++;
475                                         if ( cache ) 
476                                                 s->df[ dr - data ] = 1.0;
477                                 }
478                                 else
479                                 {
480                                         if ( cache )
481                                                 s->df[ dr - data ] += 1.0;
482                                         tmp++;
483                                 }
484                         }
485
486                         s->nelems = dr + 1 - s->elems;
487
488                         if ( cache )
489                         {
490                                 int tfm = getTFMethod();
491
492                                 for(i=0;i<s->nelems;i++)
493                                 {
494                                         stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
495                                         if ( stat )
496                                         {
497                                                 switch(tfm)
498                                                 {
499                                                         case TF_LOG:
500                                                                 s->df[i] = (1.0 + log( s->df[i] ));
501                                                         case TF_N:
502                                                                 s->df[i] *= stat->idf;
503                                                                 break;
504                                                         case TF_CONST:
505                                                                 s->df[i] = stat->idf;
506                                                                 break;
507                                                         default:
508                                                                 elog(ERROR,"Unknown TF method: %d", tfm);
509                                                 }       
510                                         }
511                                         else
512                                         {
513                                                 s->df[i] = 0.0; /* unknown word */
514                                         }
515                                 }       
516                         }
517                 }
518                 else if ( cache )
519                 {
520                         for(i=0;i<s->nelems;i++)
521                         {
522                                 stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
523                                 if ( stat )
524                                         s->df[i] = stat->idf;
525                                 else
526                                         s->df[i] = 0.0;
527                         }
528                 }
529         }
530         else if (s->nelems > 0 && cache) 
531         {
532                 stat = fingArrayStat(cache, s->info->typid, s->elems[0], stat);
533                 if ( stat )
534                         s->df[0] = stat->idf;
535                 else
536                         s->df[0] = 0.0;
537         }
538
539         return s;
540 }
541
542 static int
543 numOfIntersect(SimpleArray *a, SimpleArray *b)
544 {
545         int                             cnt = 0,
546                                         cmp;
547         Datum                   *aptr = a->elems,
548                                         *bptr = b->elems;
549         ProcTypeInfo    info = a->info;
550
551         Assert( a->info->typid == b->info->typid );
552
553         getFmgrInfoCmp(info);
554
555         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
556         {
557                 cmp = cmpArrayElem(aptr, bptr, info);
558                 if ( cmp < 0 )
559                         aptr++;
560                 else if ( cmp > 0 )
561                         bptr++;
562                 else
563                 {
564                         cnt++;
565                         aptr++;
566                         bptr++;
567                 }
568         }
569
570         return cnt;
571 }
572
573 static double
574 TFIDFSml(SimpleArray *a, SimpleArray *b)
575 {
576         int                             cmp;
577         Datum                   *aptr = a->elems,
578                                         *bptr = b->elems;
579         ProcTypeInfo    info = a->info;
580         double                  res = 0.0;
581         double                  suma = 0.0, sumb = 0.0;
582
583         Assert( a->info->typid == b->info->typid );
584         Assert( a->df );
585         Assert( b->df );
586
587         getFmgrInfoCmp(info);
588
589         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
590         {
591                 cmp = cmpArrayElem(aptr, bptr, info);
592                 if ( cmp < 0 )
593                 {
594                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
595                         aptr++;
596                 }
597                 else if ( cmp > 0 )
598                 {
599                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
600                         bptr++;
601                 }
602                 else
603                 {
604                         res += a->df[ aptr - a->elems ] * b->df[ bptr - b->elems ];
605                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
606                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
607                         aptr++;
608                         bptr++;
609                 }
610         }
611
612         /*
613          * Compute last elements
614          */
615         while( aptr - a->elems < a->nelems )
616         {
617                 suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
618                 aptr++;
619         }
620
621         while( bptr - b->elems < b->nelems )
622         {
623                 sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
624                 bptr++;
625         }
626
627         if ( suma > 0.0 && sumb > 0.0 )
628                 res = res / sqrt( suma * sumb );
629         else
630                 res = 0.0;
631
632         return res;
633 }
634
635
636 PG_FUNCTION_INFO_V1(arraysml);
637 Datum   arraysml(PG_FUNCTION_ARGS);
638 Datum
639 arraysml(PG_FUNCTION_ARGS)
640 {
641         ArrayType               *a, *b;
642         SimpleArray             *sa, *sb;
643         
644         fcinfo->flinfo->fn_extra = SearchArrayCache( 
645                                                         fcinfo->flinfo->fn_extra,
646                                                         fcinfo->flinfo->fn_mcxt,
647                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
648         fcinfo->flinfo->fn_extra = SearchArrayCache( 
649                                                         fcinfo->flinfo->fn_extra,
650                                                         fcinfo->flinfo->fn_mcxt,
651                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
652
653         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
654                 elog(ERROR,"Arguments array are not the same type!");
655
656         if (ARRISVOID(a) || ARRISVOID(b))
657                  PG_RETURN_FLOAT4(0.0);
658
659         switch(getSmlType())
660         {
661                 case    ST_TFIDF:
662                         PG_RETURN_FLOAT4( TFIDFSml(sa, sb) );
663                         break;
664                 case    ST_COSINE:
665                         {
666                                 int                             cnt;
667                                 double                  power;
668
669                                 power = ((double)(sa->nelems)) * ((double)(sb->nelems));
670                                 cnt = numOfIntersect(sa, sb);
671
672                                 PG_RETURN_FLOAT4(  ((double)cnt) / sqrt( power ) );
673                         }
674                         break;
675                 case    ST_OVERLAP:
676                         {
677                                 float4 res = (float4)numOfIntersect(sa, sb);
678
679                                 PG_RETURN_FLOAT4(res);
680                         }
681                         break;
682                 default:
683                         elog(ERROR,"Unsupported formula type of similarity");
684         }
685
686         PG_RETURN_FLOAT4(0.0); /* keep compiler quiet */ 
687 }
688
689 PG_FUNCTION_INFO_V1(arraysmlw);
690 Datum   arraysmlw(PG_FUNCTION_ARGS);
691 Datum
692 arraysmlw(PG_FUNCTION_ARGS)
693 {
694         ArrayType               *a, *b;
695         SimpleArray             *sa, *sb;
696         bool                    useIntersect = PG_GETARG_BOOL(2);
697         double                  numerator = 0.0;
698         double                  denominatorA = 0.0,
699                                         denominatorB = 0.0,
700                                         tmpA, tmpB;
701         int                             cmp;
702         ProcTypeInfo    info;
703         int                             ai = 0, bi = 0;
704
705         
706         fcinfo->flinfo->fn_extra = SearchArrayCache( 
707                                                         fcinfo->flinfo->fn_extra,
708                                                         fcinfo->flinfo->fn_mcxt,
709                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
710         fcinfo->flinfo->fn_extra = SearchArrayCache( 
711                                                         fcinfo->flinfo->fn_extra,
712                                                         fcinfo->flinfo->fn_mcxt,
713                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
714
715         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
716                 elog(ERROR,"Arguments array are not the same type!");
717
718         if (ARRISVOID(a) || ARRISVOID(b))
719                  PG_RETURN_FLOAT4(0.0);
720
721         info = sa->info;
722         if (info->tupDesc == NULL)
723                 elog(ERROR, "Only weigthed (composite) types should be used");
724         getFmgrInfoCmp(info);
725
726         while(ai < sa->nelems && bi < sb->nelems)
727         {
728                 Datum   ad = deconstructCompositeType(info, sa->elems[ai], &tmpA),
729                                 bd = deconstructCompositeType(info, sb->elems[bi], &tmpB);
730
731                 cmp = DatumGetInt32(FCall2(&info->cmpFunc, ad, bd));
732
733                 if ( cmp < 0 ) {
734                         if (useIntersect == false)
735                                 denominatorA += tmpA * tmpA;
736                         ai++;
737                 } else if ( cmp > 0 ) {
738                         if (useIntersect == false)
739                                 denominatorB += tmpB * tmpB;
740                         bi++;
741                 } else {
742                         denominatorA += tmpA * tmpA;
743                         denominatorB += tmpB * tmpB;
744                         numerator += tmpA * tmpB;
745                         ai++;
746                         bi++;
747                 }
748         }
749
750         if (useIntersect == false) {
751                 while(ai < sa->nelems) {
752                         deconstructCompositeType(info, sa->elems[ai], &tmpA);
753                         denominatorA += tmpA * tmpA;
754                         ai++;
755                 }
756
757                 while(bi < sb->nelems) {
758                         deconstructCompositeType(info, sb->elems[bi], &tmpB);
759                         denominatorB += tmpB * tmpB;
760                         bi++;
761                 }
762         }
763
764         if (numerator != 0.0) {
765                 numerator = numerator / sqrt( denominatorA * denominatorB );
766         }
767
768         PG_RETURN_FLOAT4(numerator);
769 }
770
771 PG_FUNCTION_INFO_V1(arraysml_op);
772 Datum   arraysml_op(PG_FUNCTION_ARGS);
773 Datum
774 arraysml_op(PG_FUNCTION_ARGS)
775 {
776         ArrayType               *a, *b;
777         SimpleArray             *sa, *sb;
778         double                  power = 0.0;
779         
780         fcinfo->flinfo->fn_extra = SearchArrayCache( 
781                                                         fcinfo->flinfo->fn_extra,
782                                                         fcinfo->flinfo->fn_mcxt,
783                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
784         fcinfo->flinfo->fn_extra = SearchArrayCache( 
785                                                         fcinfo->flinfo->fn_extra,
786                                                         fcinfo->flinfo->fn_mcxt,
787                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
788
789         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
790                 elog(ERROR,"Arguments array are not the same type!");
791
792         if (ARRISVOID(a) || ARRISVOID(b))
793                  PG_RETURN_BOOL(false);
794
795         switch(getSmlType())
796         {
797                 case    ST_TFIDF:
798                         power = TFIDFSml(sa, sb);
799                         break;
800                 case    ST_COSINE:
801                         {
802                                 int                             cnt;
803
804                                 power = sqrt( ((double)(sa->nelems)) * ((double)(sb->nelems)) );
805
806                                 if (  ((double)Min(sa->nelems, sb->nelems)) / power < GetSmlarLimit()  )
807                                         PG_RETURN_BOOL(false);
808
809                                 cnt = numOfIntersect(sa, sb);
810                                 power = ((double)cnt) / power;
811                         }
812                         break;
813                 case    ST_OVERLAP:
814                         power = (double)numOfIntersect(sa, sb);
815                         break;
816                 default:
817                         elog(ERROR,"Unsupported formula type of similarity");
818         }
819
820         PG_RETURN_BOOL(power >= GetSmlarLimit());
821 }
822
/*
 * Single-entry, process-local cache used by arraysml_func():
 *   cachedPlan    - SPI plan saved for the most recently used formula
 *                   (NULL while nothing is cached)
 *   cachedLen     - byte length of that formula text
 *   cachedFormula - copy of the formula text (not NUL-terminated; compare
 *                   with cachedLen bytes)
 * QBSIZE is the size of the query-text buffer built around the formula.
 */
#define QBSIZE		(8 * 1024)
static void *cachedPlan = NULL;
static int	cachedLen = 0;
static char cachedFormula[QBSIZE];
827
828 PG_FUNCTION_INFO_V1(arraysml_func);
829 Datum   arraysml_func(PG_FUNCTION_ARGS);
830 Datum
831 arraysml_func(PG_FUNCTION_ARGS)
832 {
833         ArrayType               *a, *b;
834         SimpleArray             *sa, *sb;
835         int                             cnt;
836         float4                  result = -1.0;
837         Oid                             arg[] = {INT4OID, INT4OID, INT4OID};
838         Datum                   pars[3];
839         bool                    isnull;
840         void                    *plan;
841         int                             stat;
842         text                    *formula = PG_GETARG_TEXT_P(2); 
843         
844         fcinfo->flinfo->fn_extra = SearchArrayCache( 
845                                                         fcinfo->flinfo->fn_extra,
846                                                         fcinfo->flinfo->fn_mcxt,
847                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
848         fcinfo->flinfo->fn_extra = SearchArrayCache( 
849                                                         fcinfo->flinfo->fn_extra,
850                                                         fcinfo->flinfo->fn_mcxt,
851                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
852
853         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
854                 elog(ERROR,"Arguments array are not the same type!");
855
856         if (ARRISVOID(a) || ARRISVOID(b))
857                  PG_RETURN_BOOL(false);
858
859         cnt = numOfIntersect(sa, sb);
860
861         if ( VARSIZE(formula) - VARHDRSZ > QBSIZE - 1024 )
862                 elog(ERROR,"Formula is too long");
863
864         SPI_connect();
865
866         
867         if ( cachedPlan == NULL || cachedLen != VARSIZE(formula) - VARHDRSZ ||
868                                 memcmp( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ) != 0 )
869         {
870                 char                    *ptr, buf[QBSIZE];
871
872                 *cachedFormula = '\0';
873                 if ( cachedPlan )
874                         SPI_freeplan(cachedPlan);
875                 cachedPlan = NULL;
876                 cachedLen = 0;
877
878                 ptr = stpcpy( buf, "SELECT (" );
879                 memcpy( ptr, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
880                 ptr += VARSIZE(formula) - VARHDRSZ;
881                 ptr = stpcpy( ptr, ")::float4 FROM");
882                 ptr = stpcpy( ptr, " (SELECT $1 ::float8 AS i, $2 ::float8 AS a, $3 ::float8 AS b) AS N;");
883                 *ptr = '\0';
884
885                 plan = SPI_prepare(buf, 3, arg);
886                 if (!plan)
887                         elog(ERROR, "SPI_prepare() failed");
888
889                 cachedPlan = SPI_saveplan(plan);
890                 if (!cachedPlan)
891                         elog(ERROR, "SPI_saveplan() failed");
892
893                 SPI_freeplan(plan);
894                 cachedLen = VARSIZE(formula) - VARHDRSZ;
895                 memcpy( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ); 
896         }
897
898         plan = cachedPlan;
899
900
901         pars[0] = Int32GetDatum( cnt );
902         pars[1] = Int32GetDatum( sa->nelems );
903         pars[2] = Int32GetDatum( sb->nelems );
904
905         stat = SPI_execute_plan(plan, pars, NULL, true, 3);
906         if (stat < 0)
907                 elog(ERROR, "SPI_execute_plan() returns %d", stat);
908
909         if ( SPI_processed > 0)
910                 result = DatumGetFloat4(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
911
912         SPI_finish();
913
914         PG_RETURN_FLOAT4(result);
915 }
916
917 PG_FUNCTION_INFO_V1(array_unique);
918 Datum   array_unique(PG_FUNCTION_ARGS);
919 Datum
920 array_unique(PG_FUNCTION_ARGS)
921 {
922         ArrayType               *a = PG_GETARG_ARRAYTYPE_P(0);
923         ArrayType               *res;
924         SimpleArray             *sa;
925
926         sa = Array2SimpleArrayU(NULL, a, NULL);
927
928         res = construct_array(  sa->elems, 
929                                                         sa->nelems,
930                                                         sa->info->typid,
931                                                         sa->info->typlen,
932                                                         sa->info->typbyval,
933                                                         sa->info->typalign);
934
935         pfree(sa->elems);
936         pfree(sa);
937         PG_FREE_IF_COPY(a, 0);
938
939         PG_RETURN_ARRAYTYPE_P(res);
940 }
941
942 PG_FUNCTION_INFO_V1(inarray);
943 Datum   inarray(PG_FUNCTION_ARGS);
944 Datum
945 inarray(PG_FUNCTION_ARGS)
946 {
947         ArrayType               *a;
948         SimpleArray             *sa;
949         Datum                   query = PG_GETARG_DATUM(1); 
950         Oid                             queryTypeOid;
951         Datum                   *StopLow,
952                                         *StopHigh,
953                                         *StopMiddle;
954         int                             cmp;
955
956         fcinfo->flinfo->fn_extra = SearchArrayCache( 
957                                                         fcinfo->flinfo->fn_extra,
958                                                         fcinfo->flinfo->fn_mcxt,
959                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
960
961         queryTypeOid = get_fn_expr_argtype(fcinfo->flinfo, 1);
962
963         if ( queryTypeOid == InvalidOid )
964                 elog(ERROR,"inarray: could not determine actual argument type");
965
966         if ( queryTypeOid != sa->info->typid )
967                 elog(ERROR,"inarray: Type of array's element and type of argument are not the same");
968
969         getFmgrInfoCmp(sa->info);
970         StopLow = sa->elems;
971         StopHigh = sa->elems + sa->nelems;
972
973         while (StopLow < StopHigh) 
974         {
975                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
976                 cmp = cmpArrayElem(StopMiddle, &query, sa->info);
977
978                 if ( cmp == 0 )
979                 {
980                         /* found */
981                         if ( PG_NARGS() >= 3 )
982                                 PG_RETURN_DATUM(PG_GETARG_DATUM(2));
983                         PG_RETURN_FLOAT4(1.0);
984                 }
985                 else if (cmp < 0)
986                         StopLow = StopMiddle + 1;
987                 else
988                         StopHigh = StopMiddle;
989         }
990
991         if ( PG_NARGS() >= 4 )
992                 PG_RETURN_DATUM(PG_GETARG_DATUM(3));
993         PG_RETURN_FLOAT4(0.0);
994 }