/*
 * smlar.git: smlar_stat.c
 * (commit: "support pgsql versions up to 11")
 */
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "catalog/pg_type.h"
5 #include "executor/spi.h"
6 #include "utils/array.h"
7 #include "utils/datum.h"
8 #include "utils/memutils.h"
9
10 static StatCache *PersistentDocStat = NULL;
11
12 static void*
13 cacheAlloc(MemoryContext ctx, size_t size)
14 {
15         if ( GetSmlarUsePersistent() )
16         {
17                 void *ptr = malloc(size);
18
19                 if (!ptr)
20                         ereport(ERROR,
21                                         (errcode(ERRCODE_OUT_OF_MEMORY),
22                                         errmsg("out of memory")));
23
24                 return ptr;
25         }
26
27         return  MemoryContextAlloc(ctx, size);
28 }
29
30 static void*
31 cacheAllocZero(MemoryContext ctx, size_t size)
32 {
33         void *ptr = cacheAlloc(ctx, size);
34         memset(ptr, 0, size);
35         return ptr;
36 }
37
38 struct StatCache *
39 initStatCache(MemoryContext ctx)
40 {
41         if (PersistentDocStat && GetSmlarUsePersistent())
42                 return PersistentDocStat;
43         else {
44                 int                     stat;
45                 char            buf[1024];
46                 const char      *tbl = GetSmlarTable();
47                 StatCache       *cache = NULL;
48
49                 if ( tbl == NULL || *tbl == '\0' )
50                         elog(ERROR,"smlar.stattable is not defined");
51
52                 sprintf(buf,"SELECT * FROM \"%s\" ORDER BY 1;", tbl);
53                 SPI_connect();
54                 stat = SPI_execute(buf, true, 0);
55
56                 if (stat != SPI_OK_SELECT)
57                         elog(ERROR, "SPI_execute() returns %d", stat);
58
59                 if ( SPI_processed == 0 )
60                 {
61                         elog(ERROR, "Stat table '%s' is empty", tbl);
62                 }
63                 else
64                 {
65                         int             i;
66                         double  totaldocs = 0.0;
67                         Oid             ndocType = SPI_gettypeid(SPI_tuptable->tupdesc, 2);
68
69                         if ( SPI_tuptable->tupdesc->natts != 2 )
70                                 elog(ERROR,"Stat table is not (type, int4)");
71                         if ( !(ndocType == INT4OID || ndocType == INT8OID) )
72                                 elog(ERROR,"Stat table is not (type, int4) nor (type, int8)");
73
74                         cache = cacheAllocZero(ctx, sizeof(StatCache));
75                         cache->info = findProcs( SPI_gettypeid(SPI_tuptable->tupdesc, 1) );
76                         if (cache->info->tupDesc)
77                                 elog(ERROR, "TF/IDF is not supported for composite (weighted) type");
78                         getFmgrInfoCmp(cache->info);
79                         cache->elems = cacheAlloc(ctx, sizeof(StatElem) * SPI_processed);
80
81                         for(i=0; i<SPI_processed; i++)
82                         {
83                                 bool    isnullvalue, isnullndoc;
84                                 Datum   datum = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnullvalue);
85                                 int64   ndoc;
86
87                                 if (ndocType == INT4OID)
88                                         ndoc = DatumGetInt32(SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 2, &isnullndoc));
89                                 else
90                                         ndoc = DatumGetInt64(SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 2, &isnullndoc));
91
92                                 if (isnullndoc)
93                                         elog(ERROR,"NULL value in second column of table '%s'", tbl);
94
95                                 if (isnullvalue)
96                                 {
97                                         /* total number of docs */
98
99                                         if (ndoc <= 0)
100                                                 elog(ERROR,"Total number of document should be positive");
101                                         if ( totaldocs > 0 )
102                                                 elog(ERROR,"Total number of document is repeated");
103                                         totaldocs = ndoc;
104                                 }
105                                 else
106                                 {
107                                         if ( i>0 && DatumGetInt32( FCall2( &cache->info->cmpFunc, cache->elems[i-1].datum, datum ) ) == 0 )
108                                                 elog(ERROR,"Values of first column of table '%s' are not unique", tbl);
109
110                                         if (ndoc <= 0)
111                                                 elog(ERROR,"Number of documents with current value should be positive");
112
113                                         if ( cache->info->typbyval )
114                                                 cache->elems[i].datum = datum;
115                                         else
116                                         {
117                                                 size_t  size = datumGetSize(datum, false, cache->info->typlen);
118
119                                                 cache->elems[i].datum = PointerGetDatum(cacheAlloc(ctx, size));
120                                                 memcpy(DatumGetPointer(cache->elems[i].datum), DatumGetPointer(datum), size);
121                                         }
122
123                                         cache->elems[i].idf = ndoc;
124                                 }
125                         }
126
127                         if ( totaldocs <= 0)
128                                 elog(ERROR,"Total number of document is unknown");
129                         cache->nelems = SPI_processed - 1;
130
131                         for(i=0;i<cache->nelems;i++)
132                         {
133                                 if ( totaldocs < cache->elems[i].idf )
134                                         elog(ERROR,"Inconsitent data in '%s': there is values with frequency > 1", tbl);
135                                 cache->elems[i].idf = log( totaldocs / cache->elems[i].idf + getOneAdd() );
136                         }
137                 }
138
139                 SPI_finish();
140
141                 if ( GetSmlarUsePersistent() )
142                         PersistentDocStat = cache;
143
144                 return cache;
145         }
146 }
147
148 void
149 resetStatCache(void)
150 {
151         if ( PersistentDocStat )
152         {
153
154                 if (!PersistentDocStat->info->typbyval)
155                 {
156                         int i;
157                         for(i=0;i<PersistentDocStat->nelems;i++)
158                                 free( DatumGetPointer(PersistentDocStat->elems[i].datum) );
159                 }
160
161                 if (PersistentDocStat->helems)
162                         free(PersistentDocStat->helems);
163                 free(PersistentDocStat->elems);
164                 free(PersistentDocStat);
165         }
166
167         PersistentDocStat = NULL;
168 }
169
170 StatElem  *
171 findStat(StatCache *stat, Datum query, StatElem *low)
172 {
173         StatElem        *StopLow = (low) ? low : stat->elems,
174                                 *StopHigh = stat->elems + stat->nelems,
175                                 *StopMiddle;
176         int                     cmp;
177
178         if (stat->info->tupDesc)
179                 elog(ERROR, "TF/IDF is not supported for composite (weighted) type");
180
181         while (StopLow < StopHigh)
182         {
183                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
184                 cmp = DatumGetInt32( FCall2( &stat->info->cmpFunc, StopMiddle->datum, query ) );
185
186                 if ( cmp == 0 )
187                         return StopMiddle;
188                 else if (cmp < 0)
189                         StopLow = StopMiddle + 1;
190                 else
191                         StopHigh = StopMiddle;
192         }
193
194         return NULL;
195 }
196
197 void
198 getHashStatCache(StatCache *stat, MemoryContext ctx, size_t n)
199 {
200         if ( !stat->helems )
201         {
202                 stat->helems = cacheAlloc(ctx, (stat->nelems +1) * sizeof(HashedElem));
203                 stat->selems = cacheAllocZero(ctx, n * sizeof(SignedElem));
204                 stat->nhelems = -1;
205         }
206 }